"""Feed-forward MLP brick (FFMLP) and the Fuel data streams that feed it."""

from theano import tensor

import numpy

import fuel
import blocks

from fuel.transformers import Batch, MultiProcessing, Mapping, SortMapping, Unpack
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme

from blocks.bricks import application, MLP, Rectifier, Initializable

import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream
from data.cut import TaxiTimeCutScheme
from model import ContextEmbedder

# Fixed seeds for both libraries (set at import time).
blocks.config.default_seed = 123
fuel.config.default_seed = 123


class FFMLP(Initializable):
    """Brick combining a ContextEmbedder with an MLP.

    The MLP has one Rectifier layer per entry of ``config.dim_hidden`` and,
    when ``output_layer`` is given, one extra layer of size
    ``config.dim_output`` using ``output_layer()`` as activation.

    Parameters
    ----------
    config
        Object providing ``dim_input``, ``dim_hidden`` (a list),
        ``dim_output`` (read only when ``output_layer`` is given),
        ``mlp_weights_init`` and ``mlp_biases_init``. It is also passed
        unchanged to ``ContextEmbedder``.
    output_layer : callable or None
        Zero-argument factory for the activation of the output layer.
        ``None`` means no output layer is appended.
    **kwargs
        Forwarded to ``Initializable.__init__``.
    """

    def __init__(self, config, output_layer=None, **kwargs):
        super(FFMLP, self).__init__(**kwargs)
        self.config = config

        self.context_embedder = ContextEmbedder(config)

        output_activation = [] if output_layer is None else [output_layer()]
        output_dim = [] if output_layer is None else [config.dim_output]
        self.mlp = MLP(
            activations=[Rectifier() for _ in config.dim_hidden] + output_activation,
            dims=[config.dim_input] + config.dim_hidden + output_dim)

        # Maps each extremity input name ('first_k_latitude', ...) to the
        # index (0 = latitude, 1 = longitude) used to look up the GPS
        # mean/std in `predict`.
        self.extremities = {
            '%s_k_%s' % (side, ['latitude', 'longitude'][axis]): axis
            for side in ['first', 'last']
            for axis in [0, 1]
        }
        # list(...) is required on Python 3, where dict.keys() is a view that
        # cannot be concatenated to a list; on Python 2 the result is the same.
        self.inputs = self.context_embedder.inputs + list(self.extremities)
        self.children = [self.context_embedder, self.mlp]

    def _push_initialization_config(self):
        """Copy the MLP weight/bias initialization schemes from the config."""
        self.mlp.weights_init = self.config.mlp_weights_init
        self.mlp.biases_init = self.config.mlp_biases_init

    @application(outputs=['prediction'])
    def predict(self, **kwargs):
        """Apply the MLP to the standardized extremities and the embeddings.

        ``kwargs`` must contain one entry for every name in ``self.inputs``.
        Each extremity input is centered and scaled with
        ``data.train_gps_mean`` / ``data.train_gps_std``; the results are
        concatenated along axis 1 with the outputs of the context embedder
        and passed through the MLP.
        """
        embeddings = tuple(self.context_embedder.apply(
            **{k: kwargs[k] for k in self.context_embedder.inputs}))
        extremities = tuple(
            (kwargs[k] - data.train_gps_mean[v]) / data.train_gps_std[v]
            for k, v in self.extremities.items())

        inputs = tensor.concatenate(extremities + embeddings, axis=1)
        outputs = self.mlp.apply(inputs)

        return outputs

    @predict.property('inputs')
    def predict_inputs(self):
        """Names of the inputs expected by `predict`."""
        return self.inputs


class UniformGenerator(object):
    """Callable returning uniform random floats from a RandomState seeded with 123.

    Arguments passed to the call are ignored. `Stream.train` uses an instance
    as the ``SortMapping`` key, presumably so that "sorting" a batch amounts
    to shuffling it.
    """

    def __init__(self):
        self.rng = numpy.random.RandomState(123)

    def __call__(self, *args):
        return float(self.rng.uniform())


class Stream(object):
    """Builds the train / valid / test data streams and the Theano input variables.

    Parameters
    ----------
    config
        Object providing ``max_splits``, ``n_begin_end_pts`` and
        ``batch_size``, and optionally ``use_cuts_for_training`` and
        ``shuffle_batch_size``.
    """

    def __init__(self, config):
        self.config = config

    def train(self, req_vars):
        """Return the training stream restricted to the sources in ``req_vars``."""
        stream = TaxiDataset('train', data.traintest_ds)

        # A missing attribute is treated like a false one.
        if getattr(self.config, 'use_cuts_for_training', False):
            stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme())
        else:
            stream = DataStream(
                stream,
                iteration_scheme=ShuffledExampleScheme(stream.num_examples))

        if not data.tvt:
            # Exclude the trips of the validation set from the training data.
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)

        if hasattr(self.config, 'shuffle_batch_size'):
            # Group examples into batches of `shuffle_batch_size`, reorder each
            # batch by a random key, then unpack back into single examples.
            # NOTE(review): `Batch` is resolved through the project
            # `transformers` module here, while the final batching below uses
            # the Batch imported from fuel.transformers - confirm they are the
            # same class.
            stream = transformers.Batch(
                stream,
                iteration_scheme=ConstantScheme(self.config.shuffle_batch_size))
            stream = Mapping(stream, SortMapping(key=UniformGenerator()))
            stream = Unpack(stream)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))

        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

        stream = MultiProcessing(stream)

        return stream

    def valid(self, req_vars):
        """Return the validation stream restricted to ``req_vars``, in batches of 1000."""
        stream = TaxiStream(data.valid_set, data.valid_ds)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.Select(stream, tuple(req_vars))
        return Batch(stream, iteration_scheme=ConstantScheme(1000))

    def test(self, req_vars):
        """Return the test stream in batches of ``config.batch_size``.

        NOTE(review): ``req_vars`` is not used here - unlike `train` and
        `valid`, no ``Select`` is applied. Confirm this is intentional.
        """
        stream = TaxiStream('test', data.traintest_ds)

        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
        stream = transformers.taxi_remove_test_only_clients(stream)

        return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))

    def inputs(self):
        """Return a dict mapping each source name to a new Theano variable."""
        return {
            'call_type': tensor.bvector('call_type'),
            'origin_call': tensor.ivector('origin_call'),
            'origin_stand': tensor.bvector('origin_stand'),
            'taxi_id': tensor.wvector('taxi_id'),
            'timestamp': tensor.ivector('timestamp'),
            'day_type': tensor.bvector('day_type'),
            'missing_data': tensor.bvector('missing_data'),
            'latitude': tensor.matrix('latitude'),
            'longitude': tensor.matrix('longitude'),
            'destination_latitude': tensor.vector('destination_latitude'),
            'destination_longitude': tensor.vector('destination_longitude'),
            'travel_time': tensor.ivector('travel_time'),
            'first_k_latitude': tensor.matrix('first_k_latitude'),
            'first_k_longitude': tensor.matrix('first_k_longitude'),
            'last_k_latitude': tensor.matrix('last_k_latitude'),
            'last_k_longitude': tensor.matrix('last_k_longitude'),
            'input_time': tensor.ivector('input_time'),
            'week_of_year': tensor.bvector('week_of_year'),
            'day_of_week': tensor.bvector('day_of_week'),
            'qhour_of_day': tensor.bvector('qhour_of_day'),
        }