from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing
from fuel.streams import DataStream
from fuel.schemes import ConstantScheme, ShuffledExampleScheme

from theano import tensor

import data
from data import transformers
from data.hdf5 import TaxiDataset, TaxiStream


class StreamRec(object):
    """Build the train/valid/test data streams and the Theano input variables.

    The only state is ``config``. The methods read ``config.batch_size`` and
    ``config.batch_sort_size``, and the optional attributes
    ``config.use_cuts_for_training`` and ``config.max_splits``.
    """

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _without_masks(req_vars):
        """Return the names in ``req_vars`` that do not end in ``'_mask'``.

        The mask sources are only produced by ``Padding``, so they cannot be
        selected from a stream before padding has been applied.
        """
        return tuple(v for v in req_vars if not v.endswith('_mask'))

    def _sorted_padded_batches(self, stream, req_vars):
        """Apply the batching tail shared by ``train`` and ``valid``.

        Stages, in order: select the non-mask sources, build batches with
        ``transformers.balanced_batch`` keyed on ``'latitude'``, pad the
        ``latitude``/``longitude`` sources (which adds their ``*_mask``
        sources), select exactly ``req_vars``, and wrap the result in
        ``MultiProcessing``.
        """
        stream = transformers.Select(stream, self._without_masks(req_vars))
        stream = transformers.balanced_batch(
            stream,
            key='latitude',
            batch_size=self.config.batch_size,
            batch_sort_size=self.config.batch_sort_size)
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        # Second selection: the mask sources exist only after Padding.
        stream = transformers.Select(stream, req_vars)
        return MultiProcessing(stream)

    def train(self, req_vars):
        """Return the training stream restricted to the sources in ``req_vars``.

        ``req_vars`` is a sequence of source names; names ending in
        ``'_mask'`` refer to the masks added by padding.
        """
        dataset = TaxiDataset('train', data.traintest_ds)

        if getattr(self.config, 'use_cuts_for_training', False):
            # NOTE(review): TaxiTimeCutScheme was referenced here without
            # being imported anywhere in this module, so this branch raised
            # NameError. It is presumably defined in data.cut -- confirm the
            # module path. The import is kept local so that a wrong path
            # cannot break configurations that do not use cuts.
            from data.cut import TaxiTimeCutScheme
            scheme = TaxiTimeCutScheme()
        else:
            scheme = ShuffledExampleScheme(dataset.num_examples)
        stream = DataStream(dataset, iteration_scheme=scheme)

        if not data.tvt:
            # Load every trip_id of the validation set and hand them to
            # TaxiExcludeTrips (presumably to keep them out of training).
            valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
            valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
            stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)

        if hasattr(self.config, 'max_splits'):
            stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits)
        elif not data.tvt:
            stream = transformers.add_destination(stream)

        stream = transformers.TaxiExcludeEmptyTrips(stream)
        stream = transformers.taxi_add_datetime(stream)
        return self._sorted_padded_batches(stream, req_vars)

    def valid(self, req_vars):
        """Return the validation stream restricted to the sources in ``req_vars``."""
        stream = TaxiStream(data.valid_set, data.valid_ds)
        stream = transformers.taxi_add_datetime(stream)
        return self._sorted_padded_batches(stream, req_vars)

    def test(self, req_vars):
        """Return the test stream restricted to the sources in ``req_vars``.

        Unlike ``train`` and ``valid``, batches are cut with a plain
        ``ConstantScheme`` of ``config.batch_size`` examples and the stream is
        not wrapped in ``MultiProcessing``.
        """
        stream = TaxiStream('test', data.traintest_ds)
        stream = transformers.taxi_add_datetime(stream)
        stream = transformers.taxi_remove_test_only_clients(stream)
        stream = transformers.Select(stream, self._without_masks(req_vars))
        stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
        stream = Padding(stream, mask_sources=['latitude', 'longitude'])
        # Second selection: the mask sources exist only after Padding.
        stream = transformers.Select(stream, req_vars)
        return stream

    def inputs(self):
        """Return a dict mapping each source name to a fresh Theano variable.

        Per-trip scalar sources are vectors (one entry per example); the
        ``latitude``/``longitude`` sources and their masks are matrices.
        """
        return {
            # Trip metadata.
            'call_type': tensor.bvector('call_type'),
            'origin_call': tensor.ivector('origin_call'),
            'origin_stand': tensor.bvector('origin_stand'),
            'taxi_id': tensor.wvector('taxi_id'),
            'timestamp': tensor.ivector('timestamp'),
            'day_type': tensor.bvector('day_type'),
            'missing_data': tensor.bvector('missing_data'),
            # Padded sources and the masks added by Padding.
            'latitude': tensor.matrix('latitude'),
            'longitude': tensor.matrix('longitude'),
            'latitude_mask': tensor.matrix('latitude_mask'),
            'longitude_mask': tensor.matrix('longitude_mask'),
            # Remaining per-trip sources.
            'destination_latitude': tensor.vector('destination_latitude'),
            'destination_longitude': tensor.vector('destination_longitude'),
            'travel_time': tensor.ivector('travel_time'),
            'input_time': tensor.ivector('input_time'),
            'week_of_year': tensor.bvector('week_of_year'),
            'day_of_week': tensor.bvector('day_of_week'),
            'qhour_of_day': tensor.bvector('qhour_of_day'),
        }