aboutsummaryrefslogtreecommitdiff
path: root/model/rnn.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:27 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:30:40 -0400
commitff49937eef024916ac4560ce0134d94006e9e2e5 (patch)
tree5b8faba873efc03151ee457c459510c8f64b65f3 /model/rnn.py
parentac49e3cb892e0278ea1d52afdc314322000fae27 (diff)
downloadtaxi-ff49937eef024916ac4560ce0134d94006e9e2e5.tar.gz
taxi-ff49937eef024916ac4560ce0134d94006e9e2e5.zip
RNN & Bidir RNN refactoring (& fixes, maybe)
Diffstat (limited to 'model/rnn.py')
-rw-r--r--model/rnn.py63
1 files changed, 2 insertions, 61 deletions
diff --git a/model/rnn.py b/model/rnn.py
index b4c6550..1b192ae 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -16,6 +16,8 @@ from data.hdf5 import TaxiDataset, TaxiStream
import error
+from model.stream import StreamRec as Stream
+
class RNN(Initializable):
@lazy()
def __init__(self, config, rec_input_len=2, output_dim=2, **kwargs):
@@ -141,64 +143,3 @@ class RNN(Initializable):
return self.inputs + ['destination_latitude', 'destination_longitude']
-class Stream(object):
- def __init__(self, config):
- self.config = config
-
- def train(self, req_vars):
- stream = TaxiDataset('train', data.traintest_ds)
- stream = DataStream(stream, iteration_scheme=ShuffledExampleScheme(stream.num_examples))
- if not data.tvt:
- valid = TaxiDataset(data.valid_set, data.valid_ds, sources=('trip_id',))
- valid_trips_ids = valid.get_data(None, slice(0, valid.num_examples))[0]
- stream = transformers.TaxiExcludeTrips(stream, valid_trips_ids)
-
- stream = transformers.TaxiExcludeEmptyTrips(stream)
- stream = transformers.taxi_add_datetime(stream)
- if not data.tvt:
- stream = transformers.add_destination(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def valid(self, req_vars):
- stream = TaxiStream(data.valid_set, data.valid_ds)
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def test(self, req_vars):
- stream = TaxiStream('test', data.traintest_ds)
- stream = transformers.taxi_add_datetime(stream)
- stream = transformers.taxi_remove_test_only_clients(stream)
- stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
-
- stream = Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
- stream = Padding(stream, mask_sources=['latitude', 'longitude'])
- stream = transformers.Select(stream, req_vars)
- return stream
-
- def inputs(self):
- return {'call_type': tensor.bvector('call_type'),
- 'origin_call': tensor.ivector('origin_call'),
- 'origin_stand': tensor.bvector('origin_stand'),
- 'taxi_id': tensor.wvector('taxi_id'),
- 'timestamp': tensor.ivector('timestamp'),
- 'day_type': tensor.bvector('day_type'),
- 'missing_data': tensor.bvector('missing_data'),
- 'latitude': tensor.matrix('latitude'),
- 'longitude': tensor.matrix('longitude'),
- 'latitude_mask': tensor.matrix('latitude_mask'),
- 'longitude_mask': tensor.matrix('longitude_mask'),
- 'week_of_year': tensor.bvector('week_of_year'),
- 'day_of_week': tensor.bvector('day_of_week'),
- 'qhour_of_day': tensor.bvector('qhour_of_day'),
- 'destination_latitude': tensor.vector('destination_latitude'),
- 'destination_longitude': tensor.vector('destination_longitude')}