diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:27 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:40 -0400 |
commit | ff49937eef024916ac4560ce0134d94006e9e2e5 (patch) | |
tree | 5b8faba873efc03151ee457c459510c8f64b65f3 /model/stream.py | |
parent | ac49e3cb892e0278ea1d52afdc314322000fae27 (diff) | |
download | taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip |
RNN & Bidir RNN refactoring (& fixes, maybe)
Diffstat (limited to 'model/stream.py')
-rw-r--r-- | model/stream.py | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/model/stream.py b/model/stream.py new file mode 100644 index 0000000..88b1d7f --- /dev/null +++ b/model/stream.py @@ -0,0 +1,94 @@ +from fuel.transformers import Batch, Padding, Mapping, SortMapping, Unpack, MultiProcessing +from fuel.streams import DataStream +from fuel.schemes import ConstantScheme, ShuffledExampleScheme + +from theano import tensor + +import data +from data import transformers +from data.hdf5 import TaxiDataset, TaxiStream + + +class StreamRec(object): + def __init__(self, config): + self.config = config + + def train(self, req_vars): + stream = TaxiDataset('train', data.traintest_ds) + + if hasattr(self.config, 'use_cuts_for_training') and self.config.use_cuts_for_training: + stream = DataStream(stream, iteration_scheme=TaxiTimeCutScheme()) + else: + stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) + + if not data.tvt: + valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) + valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] + stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) + + if hasattr(self.config, 'max_splits'): + stream = transformers.TaxiGenerateSplits(stream, max_splits=self.config.max_splits) + elif not data.tvt: + stream = transformers.add_destination(stream) + + stream = transformers.TaxiExcludeEmptyTrips(stream) + stream = transformers.taxi_add_datetime(stream) + + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = transformers.balanced_batch(stream, key='latitude', + batch_size=self.config.batch_size, + batch_sort_size=self.config.batch_sort_size) + + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + + stream = transformers.Select(stream, req_vars) + + stream = MultiProcessing(stream) + + return stream + + def valid(self, req_vars): + stream = TaxiStream(data.valid_set, data.valid_ds) + + stream = transformers.taxi_add_datetime(stream) + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + stream = transformers.Select(stream, req_vars) + return stream + + def test(self, req_vars): + stream = TaxiStream('test', data.traintest_ds) + + stream = transformers.taxi_add_datetime(stream) + stream = transformers.taxi_remove_test_only_clients(stream) + + stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) + + stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) + stream = Padding(stream, mask_sources=['latitude', 'longitude']) + stream = transformers.Select(stream, req_vars) + return stream + + def inputs(self): + return {'call_type': tensor.bvector('call_type'), + 'origin_call': tensor.ivector('origin_call'), + 'origin_stand': tensor.bvector('origin_stand'), + 'taxi_id': tensor.wvector('taxi_id'), + 'timestamp': tensor.ivector('timestamp'), + 'day_type': tensor.bvector('day_type'), + 'missing_data': tensor.bvector('missing_data'), + 'latitude': tensor.matrix('latitude'), + 'longitude': tensor.matrix('longitude'), + 'latitude_mask': tensor.matrix('latitude_mask'), + 'longitude_mask': tensor.matrix('longitude_mask'), + 'destination_latitude': tensor.vector('destination_latitude'), + 'destination_longitude': tensor.vector('destination_longitude'), + 'travel_time': tensor.ivector('travel_time'), + 'input_time': tensor.ivector('input_time'), + 'week_of_year': tensor.bvector('week_of_year'), + 'day_of_week': tensor.bvector('day_of_week'), + 'qhour_of_day': tensor.bvector('qhour_of_day')} + |