aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
blob: d6ed61163ee96cedb6937757ee6595907a9cefab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from fuel.transformers import Transformer, Filter, Mapping
import numpy
import theano
import random
import data

import datetime

def at_least_k(k, v, pad_at_begin, is_longitude):
    """Return ``v`` extended by edge replication to a length of at least ``k``.

    Args:
        k: minimum length of the result.
        v: 1-D sequence of coordinates.
        pad_at_begin: if true, repeat ``v[0]`` in front; otherwise repeat
            ``v[-1]`` at the end.
        is_longitude: only used when ``v`` is empty, to pick the fallback
            value ``data.porto_center[1]`` (True) or ``data.porto_center[0]``
            (False).

    Returns:
        ``v`` itself when it is already at least ``k`` long, otherwise a new
        numpy array of length ``k``.
    """
    if len(v) == 0:
        # Nothing to replicate: seed with a single fallback coordinate.
        center_idx = 1 if is_longitude else 0
        v = numpy.array([data.porto_center[center_idx]], dtype=theano.config.floatX)
    missing = k - len(v)
    if missing <= 0:
        return v
    if pad_at_begin:
        return numpy.concatenate((numpy.full((missing,), v[0]), v))
    return numpy.concatenate((v, numpy.full((missing,), v[-1])))


class Select(Transformer):
    """Keep only the requested sources of a stream, in the order given.

    ``sources`` names entries of ``data_stream.sources``; a name that is not
    present makes ``__init__`` raise ``ValueError`` (from ``.index``).
    """

    def __init__(self, data_stream, sources):
        super(Select, self).__init__(data_stream)
        # Materialize first so a one-shot iterable is not consumed by the
        # index lookup before being stored, and store as a tuple so that
        # downstream transformers can extend it with ``sources + (...)``.
        sources = tuple(sources)
        self.ids = [data_stream.sources.index(name) for name in sources]
        self.sources = sources

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        example = next(self.child_epoch_iterator)
        # A tuple (not a list) so ``example + (...)`` in the Taxi*
        # transformers works when they are stacked on top of Select.
        return tuple(example[i] for i in self.ids)
        
class TaxiGenerateSplits(Transformer):
    """Turn each complete trip into several prefix examples.

    For every trip pulled from the child stream, the cut indices
    ``0 .. len(trip) - 1`` are shuffled, then truncated to ``max_splits``
    when that is not -1. Each call to ``get_data`` emits the next cut:

    - the latitude/longitude sources are replaced by the prefix ending at
      that cut;
    - the trip's final point is appended as ``destination_latitude`` /
      ``destination_longitude``.
    """

    def __init__(self, data_stream, max_splits=-1):
        super(TaxiGenerateSplits, self).__init__(data_stream)
        self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude')
        self.max_splits = max_splits
        self.data = None   # trip currently being split
        self.splits = []   # cut indices still to be emitted for self.data
        self.isplit = 0    # position of the next cut within self.splits
        self.id_latitude = data_stream.sources.index('latitude')
        self.id_longitude = data_stream.sources.index('longitude')

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        # Pull trips until one has a cut left to emit; a trip with no points
        # yields an empty cut list and is skipped. StopIteration from the
        # child iterator propagates to the caller.
        while self.isplit >= len(self.splits):
            self.data = next(self.child_epoch_iterator)
            # list(...) is required: random.shuffle mutates in place and on
            # Python 3 range() is an immutable sequence (TypeError).
            self.splits = list(range(len(self.data[self.id_longitude])))
            random.shuffle(self.splits)
            if self.max_splits != -1 and len(self.splits) > self.max_splits:
                self.splits = self.splits[:self.max_splits]
            self.isplit = 0

        n = self.splits[self.isplit] + 1  # prefix length, 1 .. len(trip)
        self.isplit += 1

        r = list(self.data)
        r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
        r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)

        # NOTE(review): destination is always float32, whereas the prefixes
        # use theano.config.floatX; confirm this is intended.
        dlat = numpy.float32(self.data[self.id_latitude][-1])
        dlon = numpy.float32(self.data[self.id_longitude][-1])

        return tuple(r + [dlat, dlon])

class TaxiAddFirstK(Transformer):
    """Append the first ``k`` latitude and longitude values of each example.

    Adds the sources ``first_k_latitude`` and ``first_k_longitude``.
    Sequences shorter than ``k`` are padded at the end by repeating their
    last value (via ``at_least_k``).
    """

    def __init__(self, k, stream):
        super(TaxiAddFirstK, self).__init__(stream)
        self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
        self.id_latitude = stream.sources.index('latitude')
        self.id_longitude = stream.sources.index('longitude')
        self.k = k

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        example = next(self.child_epoch_iterator)
        first_k = (numpy.array(at_least_k(self.k, example[self.id_latitude], False, False)[:self.k],
                               dtype=theano.config.floatX),
                   numpy.array(at_least_k(self.k, example[self.id_longitude], False, True)[:self.k],
                               dtype=theano.config.floatX))
        # tuple(...) so a child stream yielding lists does not raise
        # TypeError on list + tuple concatenation.
        return tuple(example) + first_k

class TaxiAddLastK(Transformer):
    """Append the last ``k`` latitude and longitude values of each example.

    Adds the sources ``last_k_latitude`` and ``last_k_longitude``.
    Sequences shorter than ``k`` are padded at the front by repeating their
    first value (via ``at_least_k``).
    """

    def __init__(self, k, stream):
        super(TaxiAddLastK, self).__init__(stream)
        self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
        self.id_latitude = stream.sources.index('latitude')
        self.id_longitude = stream.sources.index('longitude')
        self.k = k

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        example = next(self.child_epoch_iterator)
        lat = at_least_k(self.k, example[self.id_latitude], True, False)
        lon = at_least_k(self.k, example[self.id_longitude], True, True)
        # Slice from len - k rather than -k: with k == 0, ``v[-0:]`` would
        # return the whole sequence instead of an empty one.
        last_k = (numpy.array(lat[len(lat) - self.k:], dtype=theano.config.floatX),
                  numpy.array(lon[len(lon) - self.k:], dtype=theano.config.floatX))
        # tuple(...) so a child stream yielding lists does not raise
        # TypeError on list + tuple concatenation.
        return tuple(example) + last_k

class TaxiAddDateTime(Transformer):
    """Append calendar features derived from each example's ``timestamp``.

    The timestamp is decoded with ``datetime.utcfromtimestamp`` (seconds
    since the epoch, UTC). The added sources are:

    - ``week_of_year``: ISO week number minus 1, so it starts at 0.
    - ``day_of_week``: ``date.weekday()``, where Monday is 0.
    - ``qhour_of_day``: quarter-hour index within the day, 0 .. 95.
    """

    def __init__(self, stream):
        super(TaxiAddDateTime, self).__init__(stream)
        self.sources = stream.sources + ('week_of_year', 'day_of_week', 'qhour_of_day')
        self.id_timestamp = stream.sources.index('timestamp')

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        example = next(self.child_epoch_iterator)
        ts = example[self.id_timestamp]
        date = datetime.datetime.utcfromtimestamp(ts)
        # Floor division: with true division (Python 3 `/`) the quarter-hour
        # index becomes a non-integer float such as 2.4666.
        info = (date.isocalendar()[1] - 1, date.weekday(), date.hour * 4 + date.minute // 15)
        # tuple(...) so a child stream yielding lists does not raise
        # TypeError on list + tuple concatenation.
        return tuple(example) + info

class TaxiExcludeTrips(Transformer):
    """Drop examples whose ``trip_id`` is listed in ``exclude_list``.

    Does not override ``sources``; examples that pass the filter are
    returned unchanged.
    """

    def __init__(self, exclude_list, stream):
        super(TaxiExcludeTrips, self).__init__(stream)
        self.id_trip_id = stream.sources.index('trip_id')
        # A set is the idiomatic container for pure membership tests.
        self.exclude = set(exclude_list)

    def get_data(self, request=None):
        if request is not None:
            raise ValueError
        # Skip excluded examples; StopIteration from the child iterator
        # propagates to the caller.
        example = next(self.child_epoch_iterator)
        while example[self.id_trip_id] in self.exclude:
            example = next(self.child_epoch_iterator)
        return example