diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:27 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:30:40 -0400 |
commit | ff49937eef024916ac4560ce0134d94006e9e2e5 (patch) | |
tree | 5b8faba873efc03151ee457c459510c8f64b65f3 /model/rnn.py | |
parent | ac49e3cb892e0278ea1d52afdc314322000fae27 (diff) | |
download | taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip |
RNN & Bidir RNN refactoring (& fixes, maybe)
Diffstat (limited to 'model/rnn.py')
-rw-r--r-- | model/rnn.py | 63 |
1 files changed, 2 insertions, 61 deletions
diff --git a/model/rnn.py b/model/rnn.py index b4c6550..1b192ae 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -16,6 +16,8 @@ from data.hdf5 import TaxiDataset, TaxiStream import error +from model.stream import StreamRec as Stream + class RNN(Initializable): @lazy() def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs): @@ -141,64 +143,3 @@ class RNN(Initializable): return self.inputs + ['destination_latitude', 'destination_longitude'] -class Stream(object): - def __init__(self, config): - self.config = config - - def train(self, req_vars): - stream = TaxiDataset('train', data.traintest_ds) - stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples)) - if not data.tvt: - valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',)) - valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0] - stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids) - - stream = transformers.TaxiExcludeEmptyTrips(stream) - stream = transformers.taxi_add_datetime(stream) - if not data.tvt: - stream = transformers.add_destination(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def valid(self, req_vars): - stream = TaxiStream(data.valid_set, data.valid_ds) - stream = transformers.taxi_add_datetime(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def test(self, req_vars): - stream = TaxiStream('test', data.traintest_ds) - stream = transformers.taxi_add_datetime(stream) - stream = transformers.taxi_remove_test_only_clients(stream) - stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - - stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) - stream = Padding(stream, mask_sources=['latitude', 'longitude']) - stream = transformers.Select(stream, req_vars) - return stream - - def inputs(self): - return {'call_type': tensor.bvector('call_type'), - 'origin_call': tensor.ivector('origin_call'), - 'origin_stand': tensor.bvector('origin_stand'), - 'taxi_id': tensor.wvector('taxi_id'), - 'timestamp': tensor.ivector('timestamp'), - 'day_type': tensor.bvector('day_type'), - 'missing_data': tensor.bvector('missing_data'), - 'latitude': tensor.matrix('latitude'), - 'longitude': tensor.matrix('longitude'), - 'latitude_mask': tensor.matrix('latitude_mask'), - 'longitude_mask': tensor.matrix('longitude_mask'), - 'week_of_year': tensor.bvector('week_of_year'), - 'day_of_week': tensor.bvector('day_of_week'), - 'qhour_of_day': tensor.bvector('qhour_of_day'), - 'destination_latitude': tensor.vector('destination_latitude'), - 'destination_longitude': tensor.vector('destination_longitude')} |