author | AdeB <adbrebs@gmail.com> | 2015-05-05 22:15:29 -0400 |
---|---|---|
committer | AdeB <adbrebs@gmail.com> | 2015-05-05 22:15:29 -0400 |
commit | c29a0d3f22134a8d1f5d557b325f6779c5961546 (patch) | |
tree | 6fa431d9b3595b5d2d11089920aa07ea43172d90 /data/hdf5.py | |
parent | f4d3ee6449217535bdbe19ac9c5fdd825d71b0d3 (diff) | |
parent | 1f2ff96e6480a62089fcac35154a956c218ed678 (diff) | |
download | taxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.tar.gz taxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.zip |
Merge branch 'master' of github.com:adbrebs/taxi
Diffstat (limited to 'data/hdf5.py')
-rw-r--r-- | data/hdf5.py | 61 |
1 file changed, 61 insertions, 0 deletions
```diff
diff --git a/data/hdf5.py b/data/hdf5.py
new file mode 100644
index 0000000..d848023
--- /dev/null
+++ b/data/hdf5.py
@@ -0,0 +1,61 @@
+import os
+
+import h5py
+from fuel.datasets import H5PYDataset
+from fuel.iterator import DataIterator
+from fuel.schemes import SequentialExampleScheme
+from fuel.streams import DataStream
+
+import data
+
+
+class TaxiDataset(H5PYDataset):
+    def __init__(self, which_set, filename='data.hdf5', **kwargs):
+        self.filename = filename
+        kwargs.setdefault('load_in_memory', True)
+        super(TaxiDataset, self).__init__(self.data_path, which_set, **kwargs)
+
+    @property
+    def data_path(self):
+        return os.path.join(data.path, self.filename)
+
+class TaxiStream(DataStream):
+    def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
+        dataset = TaxiDataset(which_set, filename, **kwargs)
+        if iteration_scheme is None:
+            iteration_scheme = SequentialExampleScheme(dataset.num_examples)
+        super(TaxiStream, self).__init__(dataset, iteration_scheme=iteration_scheme)
+
+_origin_calls = None
+_reverse_origin_calls = None
+
+def origin_call_unnormalize(x):
+    if _origin_calls is None:
+        _origin_calls = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_origin_call']
+    return _origin_calls[x]
+
+def origin_call_normalize(x):
+    if _reverse_origin_calls is None:
+        origin_call_unnormalize(0)
+        _reverse_origin_calls = { _origin_calls[i]: i for i in range(_origin_calls.shape[0]) }
+    return _reverse_origin_calls[x]
+
+_taxi_ids = None
+_reverse_taxi_ids = None
+
+def taxi_id_unnormalize(x):
+    if _taxi_ids is None:
+        _taxi_ids = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_taxi_id']
+    return _taxi_ids[x]
+
+def taxi_id_normalize(x):
+    if _reverse_taxi_ids is None:
+        taxi_id_unnormalize(0)
+        _reverse_taxi_ids = { _taxi_ids[i]: i for i in range(_taxi_ids.shape[0]) }
+    return _reverse_taxi_ids[x]
+
+def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
+    dataset = TaxiDataset(which_set, filename)
+    if sub is None:
+        sub = xrange(dataset.num_examples)
+    return DataIterator(DataStream(dataset), iter(sub), as_dict)
```
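For orientation, here is a minimal usage sketch of the new module. The split name `'train'` and the assumption that `data.path` points at a directory containing `data.hdf5` are illustrative only and are not confirmed by this commit:

```python
# Hypothetical usage sketch: the 'train' split name and the location of
# data.hdf5 under data.path are assumptions, not part of this commit.
from data.hdf5 import TaxiDataset, TaxiStream, taxi_it

# TaxiDataset defaults to load_in_memory=True and resolves the HDF5 file
# relative to data.path.
dataset = TaxiDataset('train')
n = dataset.num_examples

# TaxiStream falls back to a SequentialExampleScheme over all examples
# when no iteration scheme is given.
stream = TaxiStream('train')
first = next(stream.get_epoch_iterator(as_dict=True))

# taxi_it wraps a DataIterator over an optional subset of example indices.
for row in taxi_it('train', sub=range(10)):
    pass  # each row is a dict mapping source names to values
```

One caveat worth noting: the lazy caches (`_origin_calls`, `_reverse_origin_calls`, `_taxi_ids`, `_reverse_taxi_ids`) are rebound inside the helper functions without `global` declarations, so the `is None` checks refer to local names; the normalize/unnormalize helpers would likely need `global` statements (or a different caching mechanism) to work as intended.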