aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 13:15:23 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 13:15:23 -0400
commitde76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055 (patch)
tree09c09a12861f0f6826cd33e3b77eba9a07076c49 /transformers.py
parent43e106e6630030dd34813295fe1d07bb86025402 (diff)
downloadtaxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.tar.gz
taxi-de76aae44b6c0cbe9ab42c7ae215c3ae9e4e4055.zip
Add TaxiGenerateSplits
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py
index 5ad9a87..79e8327 100644
--- a/transformers.py
+++ b/transformers.py
@@ -27,6 +27,42 @@ class Select(Transformer):
data=next(self.child_epoch_iterator)
return [data[id] for id in self.ids]
+class TaxiGenerateSplits(Transformer):
+ def __init__(self, data_stream, max_splits=-1):
+ super(TaxiGenerateSplits, self).__init__(data_stream)
+ self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude')
+ self.max_splits = max_splits
+ self.data = None
+ self.splits = []
+ self.isplit = 0
+ self.id_latitude = data_stream.sources.index('latitude')
+ self.id_longitude = data_stream.sources.index('longitude')
+
+ def get_data(self, request=None):
+ if request is not None:
+ raise ValueError
+ while self.isplit >= len(self.splits):
+ self.data = next(self.child_epoch_iterator)
+ self.splits = range(len(self.data[self.id_polyline]))
+ random.shuffle_array(self.splits)
+ if self.max_splits != -1 and len(self.splits) > self.max_splits:
+ self.splits = self.splits[:self.max_splits]
+ self.isplit = 0
+
+ i = self.isplit
+ self.isplit += 1
+ n = self.splits[i]+1
+
+ r = list(self.data)
+
+ r[self.id_latitude] = r[self.id_latitude][:n]
+ r[self.id_longitude] = r[self.id_longitude][:n]
+
+ dlat = self.data[self.id_latitude][-1]
+ dlon = self.data[self.id_longitude][-1]
+
+ return tuple(r + [dlat, dlon])
+
class first_k(object):
def __init__(self, k, id_latitude, id_longitude):
@@ -87,3 +123,18 @@ def add_destination(stream):
id_latitude = stream.sources.index('latitude')
id_longitude = stream.sources.index('longitude')
return Mapping(stream, destination(id_latitude, id_longitude), ('destination_latitude', 'destination_longitude'))
+
+
+class trip_filter(object):
+ def __init__(self, id_trip_id, exclude):
+ self.id_trip_id = id_trip_id
+ self.exclude = exclude
+ def __call__(self, data):
+ if data[self.id_trip_id] in self.exclude:
+ return False
+ else:
+ return True
+def filter_out_trips(exclude_trips, stream):
+ id_trip_id = stream.sources.index('trip_id')
+ return Filter(stream, trip_filter(id_trip_id, exclude_trips))
+