diff options
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- | data/transformers.py | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/data/transformers.py b/data/transformers.py new file mode 100644 index 0000000..1cc4834 --- /dev/null +++ b/data/transformers.py @@ -0,0 +1,127 @@ +import datetime +import random + +import numpy +import theano +from fuel.transformers import Transformer + +import data + + +def at_least_k(k, v, pad_at_begin, is_longitude): + if len(v) == 0: + v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX) + if len(v) < k: + if pad_at_begin: + v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v)) + else: + v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1]))) + return v + + +class Select(Transformer): + def __init__(self, data_stream, sources): + super(Select, self).__init__(data_stream) + self.ids = [data_stream.sources.index(source) for source in sources] + self.sources=sources + + def get_data(self, request=None): + if request is not None: + raise ValueError + data=next(self.child_epoch_iterator) + return [data[id] for id in self.ids] + +class TaxiGenerateSplits(Transformer): + def __init__(self, data_stream, max_splits=-1): + super(TaxiGenerateSplits, self).__init__(data_stream) + self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'time') + self.max_splits = max_splits + self.data = None + self.splits = [] + self.isplit = 0 + self.id_latitude = data_stream.sources.index('latitude') + self.id_longitude = data_stream.sources.index('longitude') + + def get_data(self, request=None): + if request is not None: + raise ValueError + while self.isplit >= len(self.splits): + self.data = next(self.child_epoch_iterator) + self.splits = range(len(self.data[self.id_longitude])) + random.shuffle(self.splits) + if self.max_splits != -1 and len(self.splits) > self.max_splits: + self.splits = self.splits[:self.max_splits] + self.isplit = 0 + + i = self.isplit + self.isplit += 1 + n = self.splits[i]+1 + + r = list(self.data) + + r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX) + r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX) + + dlat = numpy.float32(self.data[self.id_latitude][-1]) + dlon = numpy.float32(self.data[self.id_longitude][-1]) + + return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) + +class TaxiAddFirstK(Transformer): + def __init__(self, k, stream): + super(TaxiAddFirstK, self).__init__(stream) + self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude') + self.id_latitude = stream.sources.index('latitude') + self.id_longitude = stream.sources.index('longitude') + self.k = k + def get_data(self, request=None): + if request is not None: raise ValueError + data = next(self.child_epoch_iterator) + first_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], False, False)[:self.k], + dtype=theano.config.floatX), + numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k], + dtype=theano.config.floatX)) + return data + first_k + +class TaxiAddLastK(Transformer): + def __init__(self, k, stream): + super(TaxiAddLastK, self).__init__(stream) + self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude') + self.id_latitude = stream.sources.index('latitude') + self.id_longitude = stream.sources.index('longitude') + self.k = k + def get_data(self, request=None): + if request is not None: raise ValueError + data = next(self.child_epoch_iterator) + last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:], + dtype=theano.config.floatX), + numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], + dtype=theano.config.floatX)) + return data + last_k + +class TaxiAddDateTime(Transformer): + def __init__(self, stream): + super(TaxiAddDateTime, self).__init__(stream) + self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day') + self.id_timestamp = stream.sources.index('timestamp') + def get_data(self, request=None): + if request is not None: raise ValueError + data = next(self.child_epoch_iterator) + ts = data[self.id_timestamp] + date = datetime.datetime.utcfromtimestamp(ts) + yearweek = date.isocalendar()[1] - 1 + info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15) + return data + info + +class TaxiExcludeTrips(Transformer): + def __init__(self, exclude_list, stream): + super(TaxiExcludeTrips, self).__init__(stream) + self.id_trip_id = stream.sources.index('trip_id') + self.exclude = {v: True for v in exclude_list} + def get_data(self, request=None): + if request is not None: raise ValueError + while True: + data = next(self.child_epoch_iterator) + if not data[self.id_trip_id] in self.exclude: break + return data + |