From e1673538607a7c8d784013b21b753f0c05c4cc34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Tue, 21 Jul 2015 18:26:43 -0400 Subject: Genericize RNNs --- model/rnn.py | 82 ++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 36 deletions(-) (limited to 'model/rnn.py') diff --git a/model/rnn.py b/model/rnn.py index be17a95..bfc3122 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -16,16 +16,16 @@ from data.hdf5 import TaxiDataset, TaxiStream import error -class Model(Initializable): +class RNN(Initializable): @lazy() - def __init__(self, config, **kwargs): - super(Model, self).__init__(**kwargs) + def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs): + super(RNN, self).__init__(**kwargs) self.config = config self.pre_context_embedder = ContextEmbedder(config.pre_embedder, name='pre_context_embedder') self.post_context_embedder = ContextEmbedder(config.post_embedder, name='post_context_embedder') - in1 = 2 + sum(x[2] for x in config.pre_embedder.dim_embeddings) + in1 = rec_input_len + sum(x[2] for x in config.pre_embedder.dim_embeddings) self.input_to_rec = MLP(activations=[Tanh()], dims=[in1, config.hidden_state_dim], name='input_to_rec') self.rec = LSTM( @@ -34,7 +34,7 @@ class Model(Initializable): ) in2 = config.hidden_state_dim + sum(x[2] for x in config.post_embedder.dim_embeddings) - self.rec_to_output = MLP(activations=[Tanh()], dims=[in2, 2], name='rec_to_output') + self.rec_to_output = MLP(activations=[Tanh()], dims=[in2, output_dim], name='rec_to_output') self.sequences = ['latitude', 'latitude_mask', 'longitude'] self.context = self.pre_context_embedder.inputs + self.post_context_embedder.inputs @@ -55,59 +55,69 @@ class Model(Initializable): def get_dim(self, name): return self.rec.get_dim(name) - @application - def initial_state(self, *args, **kwargs): - return self.rec.initial_state(*args, **kwargs) + def process_rto(self, rto): + return rto - @recurrent(states=['states', 'cells'], outputs=['destination', 'states', 'cells'], sequences=['latitude', 'longitude', 'latitude_mask']) - def predict_all(self, latitude, longitude, latitude_mask, **kwargs): - latitude = (latitude - data.train_gps_mean[0]) / data.train_gps_std[0] - longitude = (longitude - data.train_gps_mean[1]) / data.train_gps_std[1] + def rec_input(self, latitude, longitude, **kwargs): + return (tensor.shape_padright(latitude), tensor.shape_padright(longitude)) + @recurrent(states=['states', 'cells'], outputs=['destination', 'states', 'cells']) + def predict_all(self, **kwargs): pre_emb = tuple(self.pre_context_embedder.apply(**kwargs)) - latitude = tensor.shape_padright(latitude) - longitude = tensor.shape_padright(longitude) - itr = self.input_to_rec.apply(tensor.concatenate(pre_emb + (latitude, longitude), axis=1)) + + itr_in = tensor.concatenate(pre_emb + self.rec_input(**kwargs), axis=1) + itr = self.input_to_rec.apply(itr_in) itr = itr.repeat(4, axis=1) - (next_states, next_cells) = self.rec.apply(itr, kwargs['states'], kwargs['cells'], mask=latitude_mask, iterate=False) + (next_states, next_cells) = self.rec.apply(itr, kwargs['states'], kwargs['cells'], mask=kwargs['latitude_mask'], iterate=False) post_emb = tuple(self.post_context_embedder.apply(**kwargs)) rto = self.rec_to_output.apply(tensor.concatenate(post_emb + (next_states,), axis=1)) - rto = (rto * data.train_gps_std) + data.train_gps_mean + rto = self.process_rto(rto) return (rto, next_states, next_cells) + @predict_all.property('sequences') + def predict_all_sequences(self): + return self.sequences + + @application(outputs=predict_all.states) + def initial_states(self, *args, **kwargs): + return self.rec.initial_states(*args, **kwargs) + @predict_all.property('contexts') - def predict_all_inputs(self): + def predict_all_context(self): return self.context + def before_predict_all(self, kwargs): + kwargs['latitude'] = (kwargs['latitude'].T - data.train_gps_mean[0]) / data.train_gps_std[0] + kwargs['longitude'] = (kwargs['longitude'].T - data.train_gps_mean[1]) / data.train_gps_std[1] + kwargs['latitude_mask'] = kwargs['latitude_mask'].T + @application(outputs=['destination']) - def predict(self, latitude, longitude, latitude_mask, **kwargs): - latitude = latitude.T - longitude = longitude.T - latitude_mask = latitude_mask.T - res = self.predict_all(latitude, longitude, latitude_mask, **kwargs)[0] - return res[-1] + def predict(self, **kwargs): + self.before_predict_all(kwargs) + res = self.predict_all(**kwargs)[0] + + last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=0) - 1, dtype='int64') + return res[last_id] @predict.property('inputs') def predict_inputs(self): return self.inputs @application(outputs=['cost_matrix']) - def cost_matrix(self, latitude, longitude, latitude_mask, **kwargs): - latitude = latitude.T - longitude = longitude.T - latitude_mask = latitude_mask.T + def cost_matrix(self, **kwargs): + self.before_predict_all(kwargs) - res = self.predict_all(latitude, longitude, latitude_mask, **kwargs)[0] + res = self.predict_all(**kwargs)[0] target = tensor.concatenate( (kwargs['destination_latitude'].dimshuffle('x', 0, 'x'), kwargs['destination_longitude'].dimshuffle('x', 0, 'x')), axis=2) - target = target.repeat(latitude.shape[0], axis=0) + target = target.repeat(kwargs['latitude'].shape[0], axis=0) ce = error.erdist(target.reshape((-1, 2)), res.reshape((-1, 2))) - ce = ce.reshape(latitude.shape) - return ce * latitude_mask + ce = ce.reshape(kwargs['latitude'].shape) + return ce * kwargs['latitude_mask'] @cost_matrix.property('inputs') def cost_matrix_inputs(self): @@ -123,8 +133,8 @@ class Model(Initializable): @application(outputs=['cost']) def valid_cost(self, **kwargs): - # Only works when batch_size is 1. - return self.cost_matrix(**kwargs)[-1,0] + last_id = tensor.cast(kwargs['latitude_mask'].sum(axis=1) - 1, dtype='int64') + return self.cost_matrix(**kwargs)[last_id].mean() @valid_cost.property('inputs') def valid_cost_inputs(self): @@ -158,7 +168,7 @@ class Stream(object): stream = transformers.add_destination(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - stream = Batch(stream, iteration_scheme=ConstantScheme(1)) + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) stream = Padding(stream, mask_sources=['latitude', 'longitude']) stream = transformers.Select(stream, req_vars) return stream @@ -169,7 +179,7 @@ class Stream(object): stream = transformers.taxi_remove_test_only_clients(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - stream = Batch(stream, iteration_scheme=ConstantScheme(1)) + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) stream = Padding(stream, mask_sources=['latitude', 'longitude']) stream = transformers.Select(stream, req_vars) return stream -- cgit v1.2.3