aboutsummaryrefslogtreecommitdiff
path: root/data/hdf5.py
diff options
context:
space:
mode:
authorAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
committerAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
commitc29a0d3f22134a8d1f5d557b325f6779c5961546 (patch)
tree6fa431d9b3595b5d2d11089920aa07ea43172d90 /data/hdf5.py
parentf4d3ee6449217535bdbe19ac9c5fdd825d71b0d3 (diff)
parent1f2ff96e6480a62089fcac35154a956c218ed678 (diff)
downloadtaxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.tar.gz
taxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.zip
Merge branch 'master' of github.com:adbrebs/taxi
Diffstat (limited to 'data/hdf5.py')
-rw-r--r--data/hdf5.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/data/hdf5.py b/data/hdf5.py
new file mode 100644
index 0000000..d848023
--- /dev/null
+++ b/data/hdf5.py
@@ -0,0 +1,61 @@
+import os
+
+import h5py
+from fuel.datasets import H5PYDataset
+from fuel.iterator import DataIterator
+from fuel.schemes import SequentialExampleScheme
+from fuel.streams import DataStream
+
+import data
+
+
+class TaxiDataset(H5PYDataset):
+ def __init__(self, which_set, filename='data.hdf5', **kwargs):
+ self.filename = filename
+ kwargs.setdefault('load_in_memory', True)
+ super(TaxiDataset, self).__init__(self.data_path, which_set, **kwargs)
+
+ @property
+ def data_path(self):
+ return os.path.join(data.path, self.filename)
+
+class TaxiStream(DataStream):
+ def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
+ dataset = TaxiDataset(which_set, filename, **kwargs)
+ if iteration_scheme is None:
+ iteration_scheme = SequentialExampleScheme(dataset.num_examples)
+ super(TaxiStream, self).__init__(dataset, iteration_scheme=iteration_scheme)
+
+_origin_calls = None
+_reverse_origin_calls = None
+
+def origin_call_unnormalize(x):
+ if _origin_calls is None:
+ _origin_calls = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_origin_call']
+ return _origin_calls[x]
+
+def origin_call_normalize(x):
+ if _reverse_origin_calls is None:
+ origin_call_unnormalize(0)
+ _reverse_origin_calls = { _origin_calls[i]: i for i in range(_origin_calls.shape[0]) }
+ return _reverse_origin_calls[x]
+
+_taxi_ids = None
+_reverse_taxi_ids = None
+
+def taxi_id_unnormalize(x):
+ if _taxi_ids is None:
+ _taxi_ids = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_taxi_id']
+ return _taxi_ids[x]
+
+def taxi_id_normalize(x):
+ if _reverse_taxi_ids is None:
+ taxi_id_unnormalize(0)
+ _reverse_taxi_ids = { _taxi_ids[i]: i for i in range(_taxi_ids.shape[0]) }
+ return _reverse_taxi_ids[x]
+
+def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
+ dataset = TaxiDataset(which_set, filename)
+ if sub is None:
+ sub = xrange(dataset.num_examples)
+ return DataIterator(DataStream(dataset), iter(sub), as_dict)