aboutsummaryrefslogtreecommitdiff
path: root/model
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:27 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:40 -0400
commitff49937eef024916ac4560ce0134d94006e9e2e5 (patch)
tree5b8faba873efc03151ee457c459510c8f64b65f3 /model
parentac49e3cb892e0278ea1d52afdc314322000fae27 (diff)
downloadtaxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz
taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip
RNN & Bidir RNN refactoring (& fixes, maybe)
Diffstat (limited to 'model')
-rw-r--r--model/bidirectional.py157
-rw-r--r--model/bidirectional_tgtcls.py5
-rw-r--r--model/rnn.py63
-rw-r--r--model/stream.py94
4 files changed, 149 insertions, 170 deletions
diff --git a/model/bidirectional.py b/model/bidirectional.py
index ba6440a..4c4ffb0 100644
--- a/model/bidirectional.py
+++ b/model/bidirectional.py
@@ -1,21 +1,33 @@
from theano import tensor
+from toolz import merge
+
from blocks.bricks import application, MLP, Initializable, Linear, Rectifier, Identity
from blocks.bricks.base import lazy
from blocks.bricks.recurrent import Bidirectional, LSTM
from blocks.utils import shared_floatx_zeros
from blocks.bricks.parallel import Fork
-from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing
-from fuel.streams import DataStream
-from fuel.schemes import ConstantScheme, ShuffledExampleScheme
-
from model import ContextEmbedder
-import data
-from data import transformers
-from data.hdf5 import TaxiDataset, TaxiStream
import error
+import data
+
+from model.stream import StreamRec as Stream
+
+class SegregatedBidirectional(Bidirectional):
+ @application
+ def apply(self, forward_dict, backward_dict):
+ """Applies forward and backward networks and concatenates outputs."""
+
+ forward = self.children[0].apply(as_list=True, **forward_dict)
+ backward = [x[::-1] for x in
+
+ self.children[1].apply(reverse=True, as_list=True,
+ **backward_dict)]
+
+ return [tensor.concatenate([f, b], axis=2)
+ for f, b in zip(forward, backward)]
class BidiRNN(Initializable):
@lazy()
@@ -25,29 +37,36 @@ class BidiRNN(Initializable):
self.context_embedder = ContextEmbedder(config)
- self.rec = Bidirectional(LSTM(dim = config.hidden_state_dim, name = 'recurrent'))
+ self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim, name='recurrent'))
- self.fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'], prototype=Linear())
+ self.fwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
+ prototype=Linear(), name='fwd_fork')
+ self.bkwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
+ prototype=Linear(), name='bkwd_fork')
rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
- self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], dims=[rto_in] + config.dim_hidden + [output_dim])
+ self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()],
+ dims=[rto_in] + config.dim_hidden + [output_dim])
self.sequences = ['latitude', 'latitude_mask', 'longitude']
self.inputs = self.sequences + self.context_embedder.inputs
- self.children = [ self.context_embedder, self.fork, self.rec, self.rec_to_output ]
+ self.children = [ self.context_embedder, self.fwd_fork, self.bkwd_fork,
+ self.rec, self.rec_to_output ]
def _push_allocation_config(self):
- self.fork.input_dim = 2
- self.fork.output_dims = [ self.rec.children[0].get_dim(name) for name in self.fork.output_names ]
- self.fork.weights_init = self.config.fork_weights_init
- self.fork.biases_init = self.config.fork_biases_init
- self.rec.weights_init = self.config.rec_weights_init
- self.rec_to_output.weights_init = self.config.mlp_weights_init
- self.rec_to_output.biases_init = self.config.mlp_biases_init
+ for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
+ fork.input_dim = 2
+ fork.output_dims = [ self.rec.children[i].get_dim(name)
+ for name in fork.output_names ]
+
+ def _push_initialization_config(self):
+ for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]:
+ brick.weights_init = self.config.weights_init
+ brick.biases_init = self.config.biases_init
def process_outputs(self, outputs):
- return outputs
+ pass # must be implemented in child class
@application(outputs=['destination'])
def predict(self, latitude, longitude, latitude_mask, **kwargs):
@@ -55,16 +74,21 @@ class BidiRNN(Initializable):
longitude = (longitude.T - data.train_gps_mean[1]) / data.train_gps_std[1]
latitude_mask = latitude_mask.T
- latitude = tensor.shape_padright(latitude)
- longitude = tensor.shape_padright(longitude)
- rec_in = tensor.concatenate((latitude, longitude), axis=2)
+ rec_in = tensor.concatenate((latitude[:, :, None], longitude[:, :, None]), axis=2)
last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')
- path = self.rec.apply(self.fork.apply(rec_in), mask=latitude_mask)[0]
+
+ path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
+ {'mask': latitude_mask}),
+ merge(self.bkwd_fork.apply(rec_in, as_dict=True),
+ {'mask': latitude_mask}))[0]
+
path_representation = (path[0][:, -self.config.hidden_state_dim:],
- path[last_id - 1, tensor.arange(latitude_mask.shape[1])][:, :self.config.hidden_state_dim])
+ path[last_id - 1, tensor.arange(latitude_mask.shape[1])]
+ [:, :self.config.hidden_state_dim])
- embeddings = tuple(self.context_embedder.apply(**{k: kwargs[k] for k in self.context_embedder.inputs }))
+ embeddings = tuple(self.context_embedder.apply(
+ **{k: kwargs[k] for k in self.context_embedder.inputs }))
inputs = tensor.concatenate(path_representation + embeddings, axis=1)
outputs = self.rec_to_output.apply(inputs)
@@ -87,87 +111,4 @@ class BidiRNN(Initializable):
def cost_inputs(self):
return self.inputs + ['destination_latitude', 'destination_longitude']
-class UniformGenerator(object):
- def __init__(self):
- self.rng = numpy.random.RandomState(123)
- def __call__(self, *args):
- return float(self.rng.uniform())
-
-class Stream(object):
- def __init__(self, config):
- self.config = config
- def train(self, req_vars):
- stream = TaxiDataset('train', data.traintest_ds)
-
- if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
- stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
- else:
- stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
-
- if not data.tvt:
- valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
-
- stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
-
- if hasattr(self.config, 'shuffle_batch_size'):
- stream = transformers.Batch(stream, iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
- stream = Mapping(stream, SortMapping(key=UniformGenerator()))
- stream = Unpack(stream)
-
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
-
- stream = transformers.Select(stream, req_vars)
- stream = MultiProcessing(stream)
-
- return stream
-
- def valid(self, req_vars):
- stream = TaxiStream(data.valid_set, data.valid_ds)
-
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def test(self, req_vars):
- stream = TaxiStream('test', data.traintest_ds)
-
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.taxi_remove_test_only_clients(stream)
-
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def inputs(self):
- return {'call_type': tensor.bvector('call_type'),
- 'origin_call': tensor.ivector('origin_call'),
- 'origin_stand': tensor.bvector('origin_stand'),
- 'taxi_id': tensor.wvector('taxi_id'),
- 'timestamp': tensor.ivector('timestamp'),
- 'day_type': tensor.bvector('day_type'),
- 'missing_data': tensor.bvector('missing_data'),
- 'latitude': tensor.matrix('latitude'),
- 'longitude': tensor.matrix('longitude'),
- 'latitude_mask': tensor.matrix('latitude_mask'),
- 'longitude_mask': tensor.matrix('longitude_mask'),
- 'destination_latitude': tensor.vector('destination_latitude'),
- 'destination_longitude': tensor.vector('destination_longitude'),
- 'travel_time': tensor.ivector('travel_time'),
- 'input_time': tensor.ivector('input_time'),
- 'week_of_year': tensor.bvector('week_of_year'),
- 'day_of_week': tensor.bvector('day_of_week'),
- 'qhour_of_day': tensor.bvector('qhour_of_day')}
diff --git a/model/bidirectional_tgtcls.py b/model/bidirectional_tgtcls.py
index 36120f7..4dfbad5 100644
--- a/model/bidirectional_tgtcls.py
+++ b/model/bidirectional_tgtcls.py
@@ -11,9 +11,12 @@ class Model(BidiRNN):
@lazy()
def __init__(self, config, **kwargs):
super(Model, self).__init__(config, output_dim=config.tgtcls.shape[0], **kwargs)
- self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX), name='classes')
+
+ self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX),
+ name='classes')
self.softmax = Softmax()
self.children.append(self.softmax)
def process_outputs(self, outputs):
return tensor.dot(self.softmax.apply(outputs), self.classes)
+
diff --git a/model/rnn.py b/model/rnn.py
index b4c6550..1b192ae 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -16,6 +16,8 @@ from data.hdf5 import TaxiDataset, TaxiStream
import error
+from model.stream import StreamRec as Stream
+
class RNN(Initializable):
@lazy()
def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs):
@@ -141,64 +143,3 @@ class RNN(Initializable):
return self.inputs + ['destination_latitude', 'destination_longitude']
-class Stream(object):
- def __init__(self, config):
- self.config = config
-
- def train(self, req_vars):
- stream = TaxiDataset('train', data.traintest_ds)
- stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- if not data.tvt:
- valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
-
- stream = transformers.TaxiExcludeEmptyTrips(stream)
- stream = transformers.taxi_add_datetime(stream)
- if not data.tvt:
- stream = transformers.add_destination(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def valid(self, req_vars):
- stream = TaxiStream(data.valid_set, data.valid_ds)
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def test(self, req_vars):
- stream = TaxiStream('test', data.traintest_ds)
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.taxi_remove_test_only_clients(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def inputs(self):
- return {'call_type': tensor.bvector('call_type'),
- 'origin_call': tensor.ivector('origin_call'),
- 'origin_stand': tensor.bvector('origin_stand'),
- 'taxi_id': tensor.wvector('taxi_id'),
- 'timestamp': tensor.ivector('timestamp'),
- 'day_type': tensor.bvector('day_type'),
- 'missing_data': tensor.bvector('missing_data'),
- 'latitude': tensor.matrix('latitude'),
- 'longitude': tensor.matrix('longitude'),
- 'latitude_mask': tensor.matrix('latitude_mask'),
- 'longitude_mask': tensor.matrix('longitude_mask'),
- 'week_of_year': tensor.bvector('week_of_year'),
- 'day_of_week': tensor.bvector('day_of_week'),
- 'qhour_of_day': tensor.bvector('qhour_of_day'),
- 'destination_latitude': tensor.vector('destination_latitude'),
- 'destination_longitude': tensor.vector('destination_longitude')}
diff --git a/model/stream.py b/model/stream.py
new file mode 100644
index 0000000..88b1d7f
--- /dev/null
+++ b/model/stream.py
@@ -0,0 +1,94 @@
+from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing
+from fuel.streams import DataStream
+from fuel.schemes import ConstantScheme, ShuffledExampleScheme
+
+from theano import tensor
+
+import data
+from data import transformers
+from data.hdf5 import TaxiDataset, TaxiStream
+
+
+class StreamRec(object):
+ def __init__(self, config):
+ self.config = config
+
+ def train(self, req_vars):
+ stream = TaxiDataset('train', data.traintest_ds)
+
+ if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
+ stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
+ else:
+ stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
+
+ if not data.tvt:
+ valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
+ valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
+ stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
+
+ if hasattr(self.config, 'max_splits'):
+ stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
+ elif not data.tvt:
+ stream = transformers.add_destination(stream)
+
+ stream = transformers.TaxiExcludeEmptyTrips(stream)
+ stream = transformers.taxi_add_datetime(stream)
+
+ stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
+
+ stream = transformers.balanced_batch(stream, key='latitude',
+ batch_size=self.config.batch_size,
+ batch_sort_size=self.config.batch_sort_size)
+
+ stream = Padding(stream, mask_sources=['latitude', 'longitude'])
+
+ stream = transformers.Select(stream, req_vars)
+
+ stream = MultiProcessing(stream)
+
+ return stream
+
+ def valid(self, req_vars):
+ stream = TaxiStream(data.valid_set, data.valid_ds)
+
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+ stream = Padding(stream, mask_sources=['latitude', 'longitude'])
+ stream = transformers.Select(stream, req_vars)
+ return stream
+
+ def test(self, req_vars):
+ stream = TaxiStream('test', data.traintest_ds)
+
+ stream = transformers.taxi_add_datetime(stream)
+ stream = transformers.taxi_remove_test_only_clients(stream)
+
+ stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
+
+ stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
+ stream = Padding(stream, mask_sources=['latitude', 'longitude'])
+ stream = transformers.Select(stream, req_vars)
+ return stream
+
+ def inputs(self):
+ return {'call_type': tensor.bvector('call_type'),
+ 'origin_call': tensor.ivector('origin_call'),
+ 'origin_stand': tensor.bvector('origin_stand'),
+ 'taxi_id': tensor.wvector('taxi_id'),
+ 'timestamp': tensor.ivector('timestamp'),
+ 'day_type': tensor.bvector('day_type'),
+ 'missing_data': tensor.bvector('missing_data'),
+ 'latitude': tensor.matrix('latitude'),
+ 'longitude': tensor.matrix('longitude'),
+ 'latitude_mask': tensor.matrix('latitude_mask'),
+ 'longitude_mask': tensor.matrix('longitude_mask'),
+ 'destination_latitude': tensor.vector('destination_latitude'),
+ 'destination_longitude': tensor.vector('destination_longitude'),
+ 'travel_time': tensor.ivector('travel_time'),
+ 'input_time': tensor.ivector('input_time'),
+ 'week_of_year': tensor.bvector('week_of_year'),
+ 'day_of_week': tensor.bvector('day_of_week'),
+ 'qhour_of_day': tensor.bvector('qhour_of_day')}
+