diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 23 |
1 files changed, 7 insertions, 16 deletions
diff --git a/data/transformers.py b/data/transformers.py index e3ff7b1..57747fc 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -67,10 +67,12 @@ class TaxiGenerateSplits(Transformer): return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) -class TaxiAddFirstK(Transformer): +class TaxiAddFirstLastLen(Transformer): def __init__(self, k, stream): - super(TaxiAddFirstK, self).__init__(stream) - self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude') + super(TaxiAddFirstLastLen, self).__init__(stream) + self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude', + 'last_k_latitude', 'last_k_longitude', + 'input_time') self.id_latitude = stream.sources.index('latitude') self.id_longitude = stream.sources.index('longitude') self.k = k @@ -81,23 +83,12 @@ class TaxiAddFirstK(Transformer): dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k], dtype=theano.config.floatX)) - return data + first_k - -class TaxiAddLastK(Transformer): - def __init__(self, k, stream): - super(TaxiAddLastK, self).__init__(stream) - self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude') - self.id_latitude = stream.sources.index('latitude') - self.id_longitude = stream.sources.index('longitude') - self.k = k - def get_data(self, request=None): - if request is not None: raise ValueError - data = next(self.child_epoch_iterator) last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:], dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) - return data + last_k + input_time = (15 * (len(data[self.id_latitude]) - 1),) + return data + first_k + last_k + input_time class TaxiAddDateTime(Transformer): def __init__(self, stream): |