aboutsummaryrefslogtreecommitdiff
path: root/model/rnn.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-21 18:26:43 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-21 18:27:55 -0400
commite1673538607a7c8d784013b21b753f0c05c4cc34 (patch)
treef42e316e0c5bf67e3c9953aad6ba8fe9656829f2 /model/rnn.py
parent58dcf7b17e9db6af53808994a7d39a759fcc5028 (diff)
downloadtaxi-e1673538607a7c8d784013b21b753f0c05c4cc34.tar.gz
taxi-e1673538607a7c8d784013b21b753f0c05c4cc34.zip
Genericize RNNs
Diffstat (limited to 'model/rnn.py')
-rw-r--r--model/rnn.py82
1 files changed, 46 insertions, 36 deletions
diff --git a/model/rnn.py b/model/rnn.py
index be17a95..bfc3122 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -16,16 +16,16 @@ from data.hdf5 import TaxiDataset, TaxiStream
import error
-class Model(Initializable):
+class RNN(Initializable):
@lazy()
- def __init__(self, config, **kwargs):
- super(Model, self).__init__(**kwargs)
+ def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs):
+ super(RNN, self).__init__(**kwargs)
self.config = config
self.pre_context_embedder = ContextEmbedder(config.pre_embedder, name='pre_context_embedder')
self.post_context_embedder = ContextEmbedder(config.post_embedder, name='post_context_embedder')
- in1 = 2 + sum(x[2] for x in config.pre_embedder.dim_embeddings)
+ in1 = rec_input_len + sum(x[2] for x in config.pre_embedder.dim_embeddings)
self.input_to_rec = MLP(activations=[Tanh()], dims=[in1, config.hidden_state_dim], name='input_to_rec')
self.rec = LSTM(
@@ -34,7 +34,7 @@ class Model(Initializable):
)
in2 = config.hidden_state_dim + sum(x[2] for x in config.post_embedder.dim_embeddings)
- self.rec_to_output = MLP(activations=[Tanh()], dims=[in2, 2], name='rec_to_output')
+ self.rec_to_output = MLP(activations=[Tanh()], dims=[in2, output_dim], name='rec_to_output')
self.sequences = ['latitude', 'latitude_mask', 'longitude']
self.context = self.pre_context_embedder.inputs + self.post_context_embedder.inputs
@@ -55,59 +55,69 @@ class Model(Initializable):
def get_dim(self, name):
return self.rec.get_dim(name)
- @application
- def initial_state(self, *args, **kwargs):
- return self.rec.initial_state(*args, **kwargs)
+ def process_rto(self, rto):
+ return rto
- @recurrent(states=['states', 'cells'], outputs=['destination', 'states', 'cells'], sequences=['latitude', 'longitude', 'latitude_mask'])
- def predict_all(self, latitude, longitude, latitude_mask, **kwargs):
- latitude = (latitude - data.train_gps_mean[0]) / data.train_gps_std[0]
- longitude = (longitude - data.train_gps_mean[1]) / data.train_gps_std[1]
+ def rec_input(self, latitude, longitude, **kwargs):
+ return (tensor.shape_padright(latitude), tensor.shape_padright(longitude))
+ @recurrent(states=['states', 'cells'], outputs=['destination', 'states', 'cells'])
+ def predict_all(self, **kwargs):
pre_emb = tuple(self.pre_context_embedder.apply(**kwargs))
- latitude = tensor.shape_padright(latitude)
- longitude = tensor.shape_padright(longitude)
- itr = self.input_to_rec.apply(tensor.concatenate(pre_emb + (latitude, longitude), axis=1))
+
+ itr_in = tensor.concatenate(pre_emb + self.rec_input(**kwargs), axis=1)
+ itr = self.input_to_rec.apply(itr_in)
itr = itr.repeat(4, axis=1)
- (next_states, next_cells) = self.rec.apply(itr, kwargs['states'], kwargs['cells'], mask=latitude_mask, iterate=False)
+ (next_states, next_cells) = self.rec.apply(itr, kwargs['states'], kwargs['cells'], mask=kwargs['latitude_mask'], iterate=False)
post_emb = tuple(self.post_context_embedder.apply(**kwargs))
rto = self.rec_to_output.apply(tensor.concatenate(post_emb + (next_states,), axis=1))
- rto = (rto * data.train_gps_std) + data.train_gps_mean
+ rto = self.process_rto(rto)
return (rto, next_states, next_cells)
+ @predict_all.property('sequences')
+ def predict_all_sequences(self):
+ return self.sequences
+
+ @application(outputs=predict_all.states)
+ def initial_states(self, *args, **kwargs):
+ return self.rec.initial_states(*args, **kwargs)
+
@predict_all.property('contexts')
- def predict_all_inputs(self):
+ def predict_all_context(self):
return self.context
+ def before_predict_all(self, kwargs):
+ kwargs['latitude'] = (kwargs['latitude'].T - data.train_gps_mean[0]) / data.train_gps_std[0]
+ kwargs['longitude'] = (kwargs['longitude'].T - data.train_gps_mean[1]) / data.train_gps_std[1]
+ kwargs['latitude_mask'] = kwargs['latitude_mask'].T
+
@application(outputs=['destination'])
- def predict(self, latitude, longitude, latitude_mask, **kwargs):
- latitude = latitude.T
- longitude = longitude.T
- latitude_mask = latitude_mask.T
- res = self.predict_all(latitude, longitude, latitude_mask, **kwargs)[0]
- return res[-1]
+ def predict(self, **kwargs):
+ self.before_predict_all(kwargs)
+ res = self.predict_all(**kwargs)[0]
+
+ last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=0) - 1, dtype='int64')
+ return res[last_id]
@predict.property('inputs')
def predict_inputs(self):
return self.inputs
@application(outputs=['cost_matrix'])
- def cost_matrix(self, latitude, longitude, latitude_mask, **kwargs):
- latitude = latitude.T
- longitude = longitude.T
- latitude_mask = latitude_mask.T
+ def cost_matrix(self, **kwargs):
+ self.before_predict_all(kwargs)
- res = self.predict_all(latitude, longitude, latitude_mask, **kwargs)[0]
+ res = self.predict_all(**kwargs)[0]
target = tensor.concatenate(
(kwargs['destination_latitude'].dimshuffle('x', 0, 'x'),
kwargs['destination_longitude'].dimshuffle('x', 0, 'x')),
axis=2)
- target = target.repeat(latitude.shape[0], axis=0)
+ target = target.repeat(kwargs['latitude'].shape[0], axis=0)
ce = error.erdist(target.reshape((-1, 2)), res.reshape((-1, 2)))
- ce = ce.reshape(latitude.shape)
- return ce * latitude_mask
+ ce = ce.reshape(kwargs['latitude'].shape)
+ return ce * kwargs['latitude_mask']
@cost_matrix.property('inputs')
def cost_matrix_inputs(self):
@@ -123,8 +133,8 @@ class Model(Initializable):
@application(outputs=['cost'])
def valid_cost(self, **kwargs):
- # Only works when batch_size is 1.
- return self.cost_matrix(**kwargs)[-1,0]
+ last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64')
+ return self.cost_matrix(**kwargs)[last_id].mean()
@valid_cost.property('inputs')
def valid_cost_inputs(self):
@@ -158,7 +168,7 @@ class Stream(object):
stream = transformers.add_destination(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
- stream = Batch(stream, iteration_scheme=ConstantScheme(1))
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
stream = transformers.Select(stream, req_vars)
return stream
@@ -169,7 +179,7 @@ class Stream(object):
stream = transformers.taxi_remove_test_only_clients(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
- stream = Batch(stream, iteration_scheme=ConstantScheme(1))
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
stream = transformers.Select(stream, req_vars)
return stream