diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:15:23 -0400 |
commit | de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055 (patch) | |
tree | 09c09a12861f0f6826cd33e3b77eba9a07076c49 /transformers.py | |
parent | 43e106e6630030dd34813295fe1d07bb86025402 (diff) | |
download | taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.tar.gz taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.zip |
Add TaxiGenerateSplits
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py index 5ad9a87..79e8327 100644 --- a/transformers.py +++ b/transformers.py @@ -27,6 +27,42 @@ class Select(Transformer): data=next(self.child_epoch_iterator) return [data[id] for id in self.ids] +class TaxiGenerateSplits(Transformer): + def __init__(self, data_stream, max_splits=-1): + super(TaxiGenerateSplits, self).__init__(data_stream) + self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude') + self.max_splits = max_splits + self.data = None + self.splits = [] + self.isplit = 0 + self.id_latitude = data_stream.sources.index('latitude') + self.id_longitude = data_stream.sources.index('longitude') + + def get_data(self, request=None): + if request is not None: + raise ValueError + while self.isplit >= len(self.splits): + self.data = next(self.child_epoch_iterator) + self.splits = range(len(self.data[self.id_polyline])) + random.shuffle_array(self.splits) + if self.max_splits != -1 and len(self.splits) > self.max_splits: + self.splits = self.splits[:self.max_splits] + self.isplit = 0 + + i = self.isplit + self.isplit += 1 + n = self.splits[i]+1 + + r = list(self.data) + + r[self.id_latitude] = r[self.id_latitude][:n] + r[self.id_longitude] = r[self.id_longitude][:n] + + dlat = self.data[self.id_latitude][-1] + dlon = self.data[self.id_longitude][-1] + + return tuple(r + [dlat, dlon]) + class first_k(object): def __init__(self, k, id_latitude, id_longitude): @@ -87,3 +123,18 @@ def add_destination(stream): id_latitude = stream.sources.index('latitude') id_longitude = stream.sources.index('longitude') return Mapping(stream, destination(id_latitude, id_longitude), ('destination_latitude', 'destination_longitude')) + + +class trip_filter(object): + def __init__(self, id_trip_id, exclude): + self.id_trip_id = id_trip_id + self.exclude = exclude + def __call__(self, data): + if data[self.id_trip_id] in self.exclude: + return False + else: + return True +def filter_out_trips(exclude_trips, stream): + id_trip_id = stream.sources.index('trip_id') + return Filter(stream, trip_filter(id_trip_id, exclude_trips)) + |