aboutsummaryrefslogblamecommitdiff
path: root/model/bidirectional_tgtcls_window.py
blob: 10693ffd4a9c9fbec636b36f0797c760c4e5d889 (plain) (tree)





























































































































































































                                                                                                      
from model.bidirectional import SegregatedBidirectional


class Model(Initializable):
    @lazy()
    def __init__(self, config, output_dim=2, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)
        
        act = config.rec_activation() if hasattr(config, 'rec_activation') else None
        self.rec = SegregatedBidirectional(LSTM(dim=config.hidden_state_dim, activation=act,
                                                name='recurrent'))

        self.fwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
                             prototype=Linear(), name='fwd_fork')
        self.bkwd_fork = Fork([name for name in self.rec.prototype.apply.sequences if name!='mask'],
                              prototype=Linear(), name='bkwd_fork')

        rto_in = config.hidden_state_dim * 2 + sum(x[2] for x in config.dim_embeddings)
        self.rec_to_output = MLP(activations=[Rectifier() for _ in config.dim_hidden] + [Identity()], 
                                 dims=[rto_in] + config.dim_hidden + [output_dim])

        self.softmax = Softmax()

        self.sequences = ['latitude', 'latitude_mask', 'longitude']
        self.inputs = self.sequences + self.context_embedder.inputs

        self.children = [ self.context_embedder, self.fwd_fork, self.bkwd_fork,
                          self.rec, self.rec_to_output, self.softmax ]

        self.classes = theano.shared(numpy.array(config.tgtcls, dtype=theano.config.floatX),
                                     name='classes')

    def _push_allocation_config(self):
        for i, fork in enumerate([self.fwd_fork, self.bkwd_fork]):
            fork.input_dim = 2 * self.config.window_size
            fork.output_dims = [ self.rec.children[i].get_dim(name)
                                 for name in fork.output_names ]

    def _push_initialization_config(self):
        for brick in [self.fwd_fork, self.bkwd_fork, self.rec, self.rec_to_output]:
            brick.weights_init = self.config.weights_init
            brick.biases_init = self.config.biases_init

    def process_outputs(self, outputs):
        return tensor.dot(self.softmax.apply(outputs), self.classes)

    @application(outputs=['destination'])
    def predict(self, latitude, longitude, latitude_mask, **kwargs):
        latitude = (latitude.dimshuffle(1, 0, 2) - data.train_gps_mean[0]) / data.train_gps_std[0]
        longitude = (longitude.dimshuffle(1, 0, 2) - data.train_gps_mean[1]) / data.train_gps_std[1]
        latitude_mask = latitude_mask.T

        rec_in = tensor.concatenate((latitude, longitude), axis=2)

        last_id = tensor.cast(latitude_mask.sum(axis=0) - 1, dtype='int64')

        path = self.rec.apply(merge(self.fwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}),
                              merge(self.bkwd_fork.apply(rec_in, as_dict=True),
                                    {'mask': latitude_mask}))[0]

        path_representation = (path[0][:, -self.config.hidden_state_dim:],
                               path[last_id - 1, tensor.arange(latitude_mask.shape[1])]
                                   [:, :self.config.hidden_state_dim])

        embeddings = tuple(self.context_embedder.apply(
                        **{k: kwargs[k] for k in self.context_embedder.inputs }))

        inputs = tensor.concatenate(path_representation + embeddings, axis=1)
        outputs = self.rec_to_output.apply(inputs)

        return self.process_outputs(outputs)

    @predict.property('inputs')
    def predict_inputs(self):
        return self.inputs

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        y_hat = self.predict(**kwargs)
        y = tensor.concatenate((kwargs['destination_latitude'][:, None],
                                kwargs['destination_longitude'][:, None]), axis=1)

        return error.erdist(y_hat, y).mean()

    @cost.property('inputs')
    def cost_inputs(self):
        return self.inputs + ['destination_latitude', 'destination_longitude']



class Stream(object):
    def __init__(self, config):
        self.config = config

    def train(self, req_vars):
        stream = TaxiDataset('train', data.traintest_ds)

        if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training:
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        if not data.tvt:
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        if hasattr(self.config, 'train_max_len'):
            idx = stream.sources.index('latitude')
            def max_len_filter(x):
                return len(x[idx]) <= self.config.train_max_len
            stream = Filter(stream, max_len_filter)

        stream = transformers.TaxiExcludeEmptyTrips(stream)

        stream = transformers.window(stream, config.window_size)
        
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(stream, key='latitude',
                                             batch_size=self.config.batch_size,
                                             batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        stream = MultiProcessing(stream)

        return stream

    def valid(self, req_vars):
        stream = TaxiStream(data.valid_set, data.valid_ds)

        stream = transformers.window(stream, config.window_size)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = transformers.balanced_batch(stream, key='latitude',
                                             batch_size=self.config.batch_size,
                                             batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        stream = MultiProcessing(stream)

        return stream

    def test(self, req_vars):
        stream = TaxiStream('test', data.traintest_ds)

        stream = transformers.window(stream, config.window_size)
        
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_remove_test_only_clients(stream)

        stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))

        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        stream = transformers.Select(stream, req_vars)
        return stream

    def inputs(self):
        return {'call_type': tensor.bvector('call_type'),
                'origin_call': tensor.ivector('origin_call'),
                'origin_stand': tensor.bvector('origin_stand'),
                'taxi_id': tensor.wvector('taxi_id'),
                'timestamp': tensor.ivector('timestamp'),
                'day_type': tensor.bvector('day_type'),
                'missing_data': tensor.bvector('missing_data'),
                'latitude': tensor.tensor('latitude'),
                'longitude': tensor.tensor('longitude'),
                'latitude_mask': tensor.matrix('latitude_mask'),
                'longitude_mask': tensor.matrix('longitude_mask'),
                'destination_latitude': tensor.vector('destination_latitude'),
                'destination_longitude': tensor.vector('destination_longitude'),
                'travel_time': tensor.ivector('travel_time'),
                'input_time': tensor.ivector('input_time'),
                'week_of_year': tensor.bvector('week_of_year'),
                'day_of_week': tensor.bvector('day_of_week'),
                'qhour_of_day': tensor.bvector('qhour_of_day')}