1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
|
import os
import h5py
from fuel.datasets import H5PYDataset
from fuel.iterator import DataIterator
from fuel.schemes import SequentialExampleScheme
from fuel.streams import DataStream
import data
class TaxiDataset(H5PYDataset):
    """H5PYDataset whose HDF5 file is looked up under ``data.path``.

    Parameters
    ----------
    which_set
        Forwarded unchanged to ``H5PYDataset.__init__``.
    filename : str
        File name joined onto ``data.path`` (default ``'data.hdf5'``).
    **kwargs
        Forwarded to ``H5PYDataset.__init__``.  ``load_in_memory`` is set
        to ``True`` unless the caller supplies it explicitly.
    """

    def __init__(self, which_set, filename='data.hdf5', **kwargs):
        # data_path reads self.filename, so it must be assigned before the
        # parent constructor is given the resolved path.
        self.filename = filename
        if 'load_in_memory' not in kwargs:
            kwargs['load_in_memory'] = True
        resolved_path = self.data_path
        super(TaxiDataset, self).__init__(resolved_path, which_set, **kwargs)

    @property
    def data_path(self):
        """Return ``data.path`` joined with ``self.filename``."""
        base_dir = data.path
        return os.path.join(base_dir, self.filename)
class TaxiStream(DataStream):
    """DataStream that builds its own TaxiDataset.

    Parameters
    ----------
    which_set
        Forwarded to ``TaxiDataset``.
    filename : str
        Forwarded to ``TaxiDataset`` (default ``'data.hdf5'``).
    iteration_scheme
        Scheme handed to ``DataStream``.  When ``None``, a
        ``SequentialExampleScheme`` covering ``dataset.num_examples`` is
        created.
    **kwargs
        Forwarded to ``TaxiDataset`` (not to ``DataStream``).
    """

    def __init__(self, which_set, filename='data.hdf5', iteration_scheme=None, **kwargs):
        taxi_data = TaxiDataset(which_set, filename, **kwargs)
        scheme = (
            iteration_scheme
            if iteration_scheme is not None
            else SequentialExampleScheme(taxi_data.num_examples)
        )
        super(TaxiStream, self).__init__(taxi_data, iteration_scheme=scheme)
# Module-level caches, filled on first use and reused afterwards.
# _origin_calls: the 'unique_origin_call' dataset (index -> value).
# _reverse_origin_calls: dict built from it (value -> index).
_origin_calls = None
_reverse_origin_calls = None


def _load_origin_calls():
    """Return the cached 'unique_origin_call' dataset, opening it on first use."""
    # Without this declaration the assignment below would make the name
    # function-local and the `is None` test would raise UnboundLocalError.
    global _origin_calls
    if _origin_calls is None:
        # The file handle is intentionally not closed here: the dataset
        # object is cached and indexed again on later calls.
        _origin_calls = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_origin_call']
    return _origin_calls


def origin_call_unnormalize(x):
    """Return the entry of 'unique_origin_call' stored at index ``x``.

    The dataset is read from ``data.hdf5`` under ``data.path`` on the first
    call and cached for subsequent calls.
    """
    return _load_origin_calls()[x]


def origin_call_normalize(x):
    """Return the index at which ``x`` is stored in 'unique_origin_call'.

    Inverse of ``origin_call_unnormalize``.  The value -> index dictionary
    is built once and cached.

    Raises
    ------
    KeyError
        If ``x`` is not one of the stored values.
    """
    global _reverse_origin_calls
    if _reverse_origin_calls is None:
        # Read every value with a single slice instead of issuing one
        # indexed read per element.
        values = _load_origin_calls()[...]
        _reverse_origin_calls = {value: index for index, value in enumerate(values)}
    return _reverse_origin_calls[x]
# Module-level caches, filled on first use and reused afterwards.
# _taxi_ids: the 'unique_taxi_id' dataset (index -> value).
# _reverse_taxi_ids: dict built from it (value -> index).
_taxi_ids = None
_reverse_taxi_ids = None


def _load_taxi_ids():
    """Return the cached 'unique_taxi_id' dataset, opening it on first use."""
    # Without this declaration the assignment below would make the name
    # function-local and the `is None` test would raise UnboundLocalError.
    global _taxi_ids
    if _taxi_ids is None:
        # The file handle is intentionally not closed here: the dataset
        # object is cached and indexed again on later calls.
        _taxi_ids = h5py.File(os.path.join(data.path, 'data.hdf5'), 'r')['unique_taxi_id']
    return _taxi_ids


def taxi_id_unnormalize(x):
    """Return the entry of 'unique_taxi_id' stored at index ``x``.

    The dataset is read from ``data.hdf5`` under ``data.path`` on the first
    call and cached for subsequent calls.
    """
    return _load_taxi_ids()[x]


def taxi_id_normalize(x):
    """Return the index at which ``x`` is stored in 'unique_taxi_id'.

    Inverse of ``taxi_id_unnormalize``.  The value -> index dictionary is
    built once and cached.

    Raises
    ------
    KeyError
        If ``x`` is not one of the stored values.
    """
    global _reverse_taxi_ids
    if _reverse_taxi_ids is None:
        # Read every value with a single slice instead of issuing one
        # indexed read per element.
        values = _load_taxi_ids()[...]
        _reverse_taxi_ids = {value: index for index, value in enumerate(values)}
    return _reverse_taxi_ids[x]
def taxi_it(which_set, filename='data.hdf5', sub=None, as_dict=True):
    """Build a ``DataIterator`` over a ``TaxiDataset``.

    Parameters
    ----------
    which_set
        Forwarded to ``TaxiDataset``.
    filename : str
        Forwarded to ``TaxiDataset`` (default ``'data.hdf5'``).
    sub : iterable, optional
        Requests handed to the iterator.  When ``None``, the integers
        ``0 .. dataset.num_examples - 1`` are used in order.
    as_dict
        Forwarded unchanged as the third positional argument of
        ``DataIterator``.

    Returns
    -------
    DataIterator
        Iterator over an unschemed ``DataStream`` wrapping the dataset.
    """
    # Local import keeps this function self-contained.
    from itertools import count, islice

    dataset = TaxiDataset(which_set, filename)
    if sub is None:
        # `xrange` does not exist on Python 3; islice(count(), n) yields the
        # same 0..n-1 sequence lazily on both Python 2 and Python 3.
        requests = islice(count(), dataset.num_examples)
    else:
        requests = iter(sub)
    return DataIterator(DataStream(dataset), requests, as_dict)
|