diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:27 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:40 -0400 |
commit | ff49937eef024916ac4560ce0134d94006e9e2e5 (patch) | |
tree | 5b8faba873efc03151ee457c459510c8f64b65f3 | |
parent | ac49e3cb892e0278ea1d52afdc314322000fae27 (diff) | |
download | taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip |
RNN & Bidir RNN refactoring (& fixes, maybe)
-rw-r--r-- | config/bidirectional_1.py | 9 | ||||
-rw-r--r-- | config/bidirectional_tgtcls_1.py | 10 | ||||
-rw-r--r-- | model/bidirectional.py | 157 | ||||
-rw-r--r-- | model/bidirectional_tgtcls.py | 5 | ||||
-rw-r--r-- | model/rnn.py | 63 | ||||
-rw-r--r-- | model/stream.py | 94 | ||||
-rwxr-xr-x | train.py | 2 |
7 files changed, 157 insertions, 183 deletions
diff --git a/config/bidirectional_1.py b/config/bidirectional_1.py index 35039d2..6970ece 100644 --- a/config/bidirectional_1.py +++ b/config/bidirectional_1.py @@ -10,7 +10,7 @@ dim_embeddings = [ ('week_of_year', 52, 10), ('day_of_week', 7, 10), ('qhour_of_day', 24 * 4, 10), - ('taxi_id', 448, 10), + ('taxi_id', data.taxi_id_size, 10), ] hidden_state_dim = 100 @@ -18,11 +18,8 @@ hidden_state_dim = 100 dim_hidden = [500, 500] embed_weights_init = IsotropicGaussian(0.01) -fork_weights_init = IsotropicGaussian(0.1) -fork_biases_init = Constant(0.01) -rec_weights_init = IsotropicGaussian(0.1) -mlp_weights_init = IsotropicGaussian(0.1) -mlp_biases_init = Constant(0.01) +weights_init = IsotropicGaussian(0.1) +biases_init = Constant(0.01) batch_size = 20 batch_sort_size = 20 diff --git a/config/bidirectional_tgtcls_1.py b/config/bidirectional_tgtcls_1.py index 88328e4..673c1e4 100644 --- a/config/bidirectional_tgtcls_1.py +++ b/config/bidirectional_tgtcls_1.py @@ -15,7 +15,7 @@ dim_embeddings = [ ('week_of_year', 52, 10), ('day_of_week', 7, 10), ('qhour_of_day', 24 * 4, 10), - ('taxi_id', 448, 10), + ('taxi_id', data.taxi_id_size, 10), ] hidden_state_dim = 100 @@ -23,13 +23,11 @@ hidden_state_dim = 100 dim_hidden = [500, 500] embed_weights_init = IsotropicGaussian(0.01) -fork_weights_init = IsotropicGaussian(0.1) -fork_biases_init = Constant(0.01) -rec_weights_init = IsotropicGaussian(0.1) -mlp_weights_init = IsotropicGaussian(0.1) -mlp_biases_init = Constant(0.01) +weights_init = IsotropicGaussian(0.1) +biases_init = Constant(0.01) batch_size = 20 batch_sort_size = 20 max_splits = 100 + diff --git a/model/bidirectional.py b/model/bidirectional.py index ba6440a..4c4ffb0 100644 --- a/model/bidirectional.py +++ b/model/bidirectional.py @@ -1,21 +1,33 @@ from theano import tensor +from toolz import merge + from blocks.bricks import application, MLP, Initializable, Linear, Rectifier, Identity from blocks.bricks.base import lazy from blocks.bricks.recurrent import Bidirectional, LSTM from blocks.utils import shared_floatx_zeros from blocks.bricks.parallel import Fork -from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing -from fuel.streams import DataStream -from fuel.schemes import ConstantScheme, ShuffledExampleScheme - from model import ContextEmbedder -import data -from data import transformers -from data.hdf5 import TaxiDataset, TaxiStream import error +import data + +from model.stream import StreamRec as Stream + +class SegregatedBidirectional(Bidirectional): + @application + def apply(self, forward_dict, backward_dict): + """Applies forward and backward networks and concatenates outputs.""" + + forward = self.children[0].apply(as_list=True, **forward_dict) + backward = [x[::-1] for x in + + self.children[1].apply(reverse=True, as_list=True, + **backward_dict)] + + return [tensor.concatenate([f, b], axis=2) + for f, b in zip(forward, backward)] class BidiRNN(Initializable): @lazy() @@ -25,29 +37,36 @@ class BidiRNN(Initializable): self.context_embedder = ContextEmbedder(config) - self.rec = Bidirectional(LSTM(dim = config.hidden_state_dim, name = 'recurrent')) + self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim, name='recurrent')) - self.fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'], prototype=Linear()) + self.fwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'], + prototype=Linear(), name='fwd_fork') + self.bkwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'], + prototype=Linear(), name='bkwd_fork') rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings) - self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[rto_in] + config.dim_hidden + [output_dim]) + self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], + dims=[rto_in] + config.dim_hidden + [output_dim]) self.sequences = ['latitude', 'latitude_mask', 'longitude'] self.inputs = self.sequences + self.context_embedder.inputs - self.children = [ self.context_embedder, self.fork, self.rec, self.rec_to_output ] + self.children = [ self.context_embedder, self.fwd_fork, self.bkwd_fork, + self.rec, self.rec_to_output ] def _push_allocation_config(self): - self.fork.input_dim = 2 - self.fork.output_dims = [ self.rec.children[0].get_dim(name) for name in self.fork.output_names ] - self.fork.weights_init = self.config.fork_weights_init - self.fork.biases_init = self.config.fork_biases_init - self.rec.weights_init = self.config.rec_weights_init - self.rec_to_output.weights_init = self.config.mlp_weights_init - self.rec_to_output.biases_init = self.config.mlp_biases_init + for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]): + fork.input_dim = 2 + fork.output_dims = [ self.rec.children[i].get_dim(name) + for name in fork.output_names ] + + def _push_initialization_config(self): + for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]: + brick.weights_init = self.config.weights_init + brick.biases_init = self.config.biases_init def process_outputs(self, outputs): - return outputs + pass # must be implemented in child class @application(outputs=['destination']) def predict(self, latitude, longitude, latitude_mask, **kwargs): @@ -55,16 +74,21 @@ class BidiRNN(Initializable): longitude = (longitude.T - data.train_gps_mean[1]) / data.train_gps_std[1] latitude_mask = latitude_mask.T - latitude = tensor.shape_padright(latitude) - longitude = tensor.shape_padright(longitude) - rec_in = tensor.concatenate((latitude, longitude), axis=2) + rec_in = tensor.concatenate((latitude[:, :, None], longitude[:, :, None]), axis=2) last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64') - path = self.rec.apply(self.fork.apply(rec_in), mask=latitude_mask)[0] + + path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True), + {'mask': latitude_mask}), + merge(self.bkwd_fork.apply(rec_in, as_dict=True), + {'mask': latitude_mask}))[0] + path_representation = (path[0][:, -self.config.hidden_state_dim:], - path[last_id - 1, tensor.arange(latitude_mask.shape[1])][:, :self.config.hidden_state_dim]) + path[last_id - 1, tensor.arange(latitude_mask.shape[1])] + [:, :self.config.hidden_state_dim]) - embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs })) + embeddings = tuple(self.context_embedder.apply( + **{k: kwargs[k] for k in self.context_embedder.inputs })) inputs = tensor.concatenate(path_representation + embeddings, axis=1) outputs = self.rec_to_output.apply(inputs) @@ -87,87 +111,4 @@ class BidiRNN(Initializable): def cost_inputs(self): return self.inputs + ['destination_latitude', 'destination_longitude'] -class UniformGenerator(object): - def __init__(self): - self.rng = numpy.random.RandomState(123) - def __call__(self, *args): - return float(self.rng.uniform()) - -class Stream(object): - def __init__(self, config): - self.config = config - def train(self, req_vars): - stream = TaxiDataset('train', data.traintest_ds) - - if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: - stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) - else: - stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - - if not data.tvt: - valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) - - stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) - - if hasattr(self.config, 'shuffle_batch_size'): - stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size)) - stream = Mapping(stream, SortMapping(key=UniformGenerator())) - stream = Unpack(stream) - - stream = transformers.taxi_add_datetime(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - - stream = transformers.Select(stream, req_vars) - stream = MultiProcessing(stream) - - return stream - - def valid(self, req_vars): - stream = TaxiStream(data.valid_set, data.valid_ds) - - stream = transformers.taxi_add_datetime(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def test(self, req_vars): - stream = TaxiStream('test', data.traintest_ds) - - stream = transformers.taxi_add_datetime(stream) - stream = transformers.taxi_remove_test_only_clients(stream) - - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def inputs(self): - return {'call_type': tensor.bvector('call_type'), - 'origin_call': tensor.ivector('origin_call'), - 'origin_stand': tensor.bvector('origin_stand'), - 'taxi_id': tensor.wvector('taxi_id'), - 'timestamp': tensor.ivector('timestamp'), - 'day_type': tensor.bvector('day_type'), - 'missing_data': tensor.bvector('missing_data'), - 'latitude': tensor.matrix('latitude'), - 'longitude': tensor.matrix('longitude'), - 'latitude_mask': tensor.matrix('latitude_mask'), - 'longitude_mask': tensor.matrix('longitude_mask'), - 'destination_latitude': tensor.vector('destination_latitude'), - 'destination_longitude': tensor.vector('destination_longitude'), - 'travel_time': tensor.ivector('travel_time'), - 'input_time': tensor.ivector('input_time'), - 'week_of_year': tensor.bvector('week_of_year'), - 'day_of_week': tensor.bvector('day_of_week'), - 'qhour_of_day': tensor.bvector('qhour_of_day')} diff --git a/model/bidirectional_tgtcls.py b/model/bidirectional_tgtcls.py index 36120f7..4dfbad5 100644 --- a/model/bidirectional_tgtcls.py +++ b/model/bidirectional_tgtcls.py @@ -11,9 +11,12 @@ class Model(BidiRNN): @lazy() def __init__(self, config, **kwargs): super(Model, self).__init__(config, output_dim=config.tgtcls.shape[0], **kwargs) - self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes') + + self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), + name='classes') self.softmax = Softmax() self.children.append(self.softmax) def process_outputs(self, outputs): return tensor.dot(self.softmax.apply(outputs), self.classes) + diff --git a/model/rnn.py b/model/rnn.py index b4c6550..1b192ae 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -16,6 +16,8 @@ from data.hdf5 import TaxiDataset, TaxiStream import error +from model.stream import StreamRec as Stream + class RNN(Initializable): @lazy() def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs): @@ -141,64 +143,3 @@ class RNN(Initializable): return self.inputs + ['destination_latitude', 'destination_longitude'] -class Stream(object): - def __init__(self, config): - self.config = config - - def train(self, req_vars): - stream = TaxiDataset('train', data.traintest_ds) - stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - if not data.tvt: - valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) - - stream = transformers.TaxiExcludeEmptyTrips(stream) - stream = transformers.taxi_add_datetime(stream) - if not data.tvt: - stream = transformers.add_destination(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def valid(self, req_vars): - stream = TaxiStream(data.valid_set, data.valid_ds) - stream = transformers.taxi_add_datetime(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def test(self, req_vars): - stream = TaxiStream('test', data.traintest_ds) - stream = transformers.taxi_add_datetime(stream) - stream = transformers.taxi_remove_test_only_clients(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def inputs(self): - return {'call_type': tensor.bvector('call_type'), - 'origin_call': tensor.ivector('origin_call'), - 'origin_stand': tensor.bvector('origin_stand'), - 'taxi_id': tensor.wvector('taxi_id'), - 'timestamp': tensor.ivector('timestamp'), - 'day_type': tensor.bvector('day_type'), - 'missing_data': tensor.bvector('missing_data'), - 'latitude': tensor.matrix('latitude'), - 'longitude': tensor.matrix('longitude'), - 'latitude_mask': tensor.matrix('latitude_mask'), - 'longitude_mask': tensor.matrix('longitude_mask'), - 'week_of_year': tensor.bvector('week_of_year'), - 'day_of_week': tensor.bvector('day_of_week'), - 'qhour_of_day': tensor.bvector('qhour_of_day'), - 'destination_latitude': tensor.vector('destination_latitude'), - 'destination_longitude': tensor.vector('destination_longitude')} diff --git a/model/stream.py b/model/stream.py new file mode 100644 index 0000000..88b1d7f --- /dev/null +++ b/model/stream.py @@ -0,0 +1,94 @@ +from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing +from fuel.streams import DataStream +from fuel.schemes import ConstantScheme, ShuffledExampleScheme + +from theano import tensor + +import data +from data import transformers +from data.hdf5 import TaxiDataset, TaxiStream + + +class StreamRec(object): + def __init__(self, config): + self.config = config + + def train(self, req_vars): + stream = TaxiDataset('train', data.traintest_ds) + + if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: + stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) + else: + stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) + + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + + if hasattr(self.config, 'max_splits'): + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) + elif not data.tvt: + stream = transformers.add_destination(stream) + + stream = transformers.TaxiExcludeEmptyTrips(stream) + stream = transformers.taxi_add_datetime(stream) + + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = transformers.balanced_batch(stream, key='latitude', + batch_size=self.config.batch_size, + batch_sort_size=self.config.batch_sort_size) + + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + + stream = transformers.Select(stream, req_vars) + + stream = MultiProcessing(stream) + + return stream + + def valid(self, req_vars): + stream = TaxiStream(data.valid_set, data.valid_ds) + + stream = transformers.taxi_add_datetime(stream) + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + stream = transformers.Select(stream, req_vars) + return stream + + def test(self, req_vars): + stream = TaxiStream('test', data.traintest_ds) + + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_remove_test_only_clients(stream) + + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + stream = transformers.Select(stream, req_vars) + return stream + + def inputs(self): + return {'call_type': tensor.bvector('call_type'), + 'origin_call': tensor.ivector('origin_call'), + 'origin_stand': tensor.bvector('origin_stand'), + 'taxi_id': tensor.wvector('taxi_id'), + 'timestamp': tensor.ivector('timestamp'), + 'day_type': tensor.bvector('day_type'), + 'missing_data': tensor.bvector('missing_data'), + 'latitude': tensor.matrix('latitude'), + 'longitude': tensor.matrix('longitude'), + 'latitude_mask': tensor.matrix('latitude_mask'), + 'longitude_mask': tensor.matrix('longitude_mask'), + 'destination_latitude': tensor.vector('destination_latitude'), + 'destination_longitude': tensor.vector('destination_longitude'), + 'travel_time': tensor.ivector('travel_time'), + 'input_time': tensor.ivector('input_time'), + 'week_of_year': tensor.bvector('week_of_year'), + 'day_of_week': tensor.bvector('day_of_week'), + 'qhour_of_day': tensor.bvector('qhour_of_day')} + @@ -110,7 +110,7 @@ if __name__ == "__main__": DataStreamMonitoring(valid_monitored, valid_stream, prefix='valid', every_n_batches=10000), - Printing(every_n_batches=1000), + Printing(every_n_batches=10000), FinishAfter(every_n_batches=10000000), SaveLoadParams(dump_path, cg, |