aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
blob: 79e8327a2513dfbdb8ac6677e13394d607457959 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from fuel.transformers import Transformer, Filter, Mapping
import numpy
import theano
import random
import data

def at_least_k(k, v, pad_at_begin, is_longitude):
    """Return ``v`` extended to at least ``k`` elements by repeating an edge value.

    An empty ``v`` is first replaced by a one-element array holding
    ``data.porto_center[1]`` when ``is_longitude`` is true and
    ``data.porto_center[0]`` otherwise. If the sequence is still shorter
    than ``k``, it is padded either at the front with copies of its first
    element (``pad_at_begin`` true) or at the back with copies of its last
    element. A sequence that already has ``k`` or more elements is returned
    unchanged (it is not truncated).
    """
    # len() is used rather than truthiness so numpy arrays are handled too.
    if len(v) == 0:
        center_index = 1 if is_longitude else 0
        v = numpy.array([data.porto_center[center_index]], dtype=theano.config.floatX)

    missing = k - len(v)
    if missing <= 0:
        return v

    if pad_at_begin:
        return numpy.concatenate((numpy.full((missing,), v[0]), v))
    return numpy.concatenate((v, numpy.full((missing,), v[-1])))


class Select(Transformer):
    """Transformer that exposes only the requested sources of its child stream.

    ``sources`` is a sequence of source names that must all be present in
    ``data_stream.sources``; their positions are resolved once at
    construction time.
    """
    def __init__(self, data_stream, sources):
        super(Select, self).__init__(data_stream)
        child_sources = data_stream.sources
        # Position of every requested source inside the child's tuples.
        self.ids = [child_sources.index(name) for name in sources]
        self.sources = sources

    def get_data(self, request=None):
        """Return the selected entries of the next child example as a list.

        Raises ``ValueError`` if a ``request`` is given.
        """
        if request is not None:
            raise ValueError
        example = next(self.child_epoch_iterator)
        return [example[position] for position in self.ids]
        
class TaxiGenerateSplits(Transformer):
    """Turn each example of the child stream into randomly chosen prefixes.

    For every example pulled from the child stream, the possible prefix
    lengths of its trajectory are shuffled and, when ``max_splits`` is not
    ``-1``, limited to the first ``max_splits`` of them. Each call to
    ``get_data`` then yields the example with its ``latitude`` and
    ``longitude`` sources cut to one of those prefixes, followed by the last
    latitude and longitude of the full trajectory as the extra sources
    ``destination_latitude`` and ``destination_longitude``.

    Parameters:
        data_stream: child stream; its ``sources`` must contain
            ``'latitude'`` and ``'longitude'``.
        max_splits: maximum number of prefixes produced per example, or
            ``-1`` (the default) for no limit.
    """
    def __init__(self, data_stream, max_splits=-1):
        super(TaxiGenerateSplits, self).__init__(data_stream)
        self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude')
        self.max_splits = max_splits
        self.data = None    # example currently being split
        self.splits = []    # shuffled prefix end indices for self.data
        self.isplit = 0     # index of the next entry of self.splits to use
        self.id_latitude = data_stream.sources.index('latitude')
        self.id_longitude = data_stream.sources.index('longitude')

    def get_data(self, request=None):
        """Return the next prefix example as a tuple.

        Raises ``ValueError`` if a ``request`` is given.
        """
        if request is not None:
            raise ValueError

        # Fetch a new example once every split of the current one has been
        # used. An example with an empty trajectory yields no splits, so the
        # loop moves straight on to the next one.
        while self.isplit >= len(self.splits):
            self.data = next(self.child_epoch_iterator)
            # The trajectory length is read from the latitude source; the
            # longitude source is assumed to have the same length
            # (TODO confirm against the dataset).
            # A real list is needed: random.shuffle works in place and a
            # Python 3 range object cannot be shuffled.
            self.splits = list(range(len(self.data[self.id_latitude])))
            random.shuffle(self.splits)
            if self.max_splits != -1 and len(self.splits) > self.max_splits:
                self.splits = self.splits[:self.max_splits]
            self.isplit = 0

        i = self.isplit
        self.isplit += 1
        # splits holds 0-based end indices, so the prefix keeps n = index + 1 points.
        n = self.splits[i] + 1

        r = list(self.data)

        r[self.id_latitude] = r[self.id_latitude][:n]
        r[self.id_longitude] = r[self.id_longitude][:n]

        # The destination is always the last point of the full trajectory.
        dlat = self.data[self.id_latitude][-1]
        dlon = self.data[self.id_longitude][-1]

        return tuple(r + [dlat, dlon])


class first_k(object):
    """Callable returning the first ``k`` latitudes and longitudes of an example.

    ``id_latitude`` and ``id_longitude`` are the positions of the coordinate
    sequences inside the example passed to ``__call__``. Sequences shorter
    than ``k`` are padded at the end by ``at_least_k``.
    """
    def __init__(self, k, id_latitude, id_longitude):
        self.k = k
        self.id_latitude = id_latitude
        self.id_longitude = id_longitude

    def __call__(self, data):
        """Return a ``(latitudes, longitudes)`` pair of arrays of dtype ``floatX``."""
        count = self.k
        lat = at_least_k(count, data[self.id_latitude], False, False)
        lon = at_least_k(count, data[self.id_longitude], False, True)
        dtype = theano.config.floatX
        return (numpy.array(lat[:count], dtype=dtype),
                numpy.array(lon[:count], dtype=dtype))
def add_first_k(k, stream):
    """Wrap ``stream`` in a ``Mapping`` that applies ``first_k``.

    The mapping is given the source names ``first_k_latitude`` and
    ``first_k_longitude``. ``stream.sources`` must contain ``'latitude'``
    and ``'longitude'``.
    """
    lat_position = stream.sources.index('latitude')
    lon_position = stream.sources.index('longitude')
    extractor = first_k(k, lat_position, lon_position)
    return Mapping(stream, extractor, ('first_k_latitude', 'first_k_longitude'))

class random_k(object):
    """Callable returning ``k`` consecutive points starting at a random offset.

    ``id_latitude`` and ``id_longitude`` are the positions of the coordinate
    sequences inside the example passed to ``__call__``. Sequences shorter
    than ``k`` are padded at the beginning by ``at_least_k``.
    """
    def __init__(self, k, id_latitude, id_longitude):
        self.k = k
        self.id_latitude = id_latitude
        self.id_longitude = id_longitude

    def __call__(self, x):
        """Return a ``(latitudes, longitudes)`` pair of arrays of dtype ``floatX``.

        The same randomly drawn window is used for both coordinates.
        """
        window = self.k
        lats = at_least_k(window, x[self.id_latitude], True, False)
        lons = at_least_k(window, x[self.id_longitude], True, True)
        # Every start offset that leaves room for a full window is a candidate.
        start = random.randrange(len(lats) - window + 1)
        stop = start + window
        dtype = theano.config.floatX
        return (numpy.array(lats[start:stop], dtype=dtype),
                numpy.array(lons[start:stop], dtype=dtype))
def add_random_k(k, stream):
    """Wrap ``stream`` in a ``Mapping`` that applies ``random_k``.

    ``stream.sources`` must contain ``'latitude'`` and ``'longitude'``.
    """
    lat_position = stream.sources.index('latitude')
    lon_position = stream.sources.index('longitude')
    extractor = random_k(k, lat_position, lon_position)
    # NOTE(review): the source names are 'last_k_*', the same names
    # add_last_k uses - presumably so the two can be swapped without
    # touching downstream consumers; confirm this is intentional.
    return Mapping(stream, extractor, ('last_k_latitude', 'last_k_longitude'))

class last_k(object):
    """Callable returning the last ``k`` latitudes and longitudes of an example.

    ``id_latitude`` and ``id_longitude`` are the positions of the coordinate
    sequences inside the example passed to ``__call__``. Sequences shorter
    than ``k`` are padded at the beginning by ``at_least_k``.
    """
    def __init__(self, k, id_latitude, id_longitude):
        self.k = k
        self.id_latitude = id_latitude
        self.id_longitude = id_longitude

    def __call__(self, data):
        """Return a ``(latitudes, longitudes)`` pair of arrays of dtype ``floatX``."""
        count = self.k
        lat = at_least_k(count, data[self.id_latitude], True, False)
        lon = at_least_k(count, data[self.id_longitude], True, True)
        dtype = theano.config.floatX
        return (numpy.array(lat[-count:], dtype=dtype),
                numpy.array(lon[-count:], dtype=dtype))
def add_last_k(k, stream):
    """Wrap ``stream`` in a ``Mapping`` that applies ``last_k``.

    The mapping is given the source names ``last_k_latitude`` and
    ``last_k_longitude``. ``stream.sources`` must contain ``'latitude'``
    and ``'longitude'``.
    """
    lat_position = stream.sources.index('latitude')
    lon_position = stream.sources.index('longitude')
    extractor = last_k(k, lat_position, lon_position)
    return Mapping(stream, extractor, ('last_k_latitude', 'last_k_longitude'))

class destination(object):
    """Callable returning the final latitude and longitude of an example.

    ``id_latitude`` and ``id_longitude`` are the positions of the coordinate
    sequences inside the example passed to ``__call__``. An empty sequence
    is replaced by the fallback value supplied by ``at_least_k``.
    """
    def __init__(self, id_latitude, id_longitude):
        self.id_latitude = id_latitude
        self.id_longitude = id_longitude

    def __call__(self, data):
        """Return a ``(latitude, longitude)`` pair of arrays of dtype ``floatX``."""
        final_lat = at_least_k(1, data[self.id_latitude], True, False)[-1]
        final_lon = at_least_k(1, data[self.id_longitude], True, True)[-1]
        dtype = theano.config.floatX
        return (numpy.array(final_lat, dtype=dtype),
                numpy.array(final_lon, dtype=dtype))
def add_destination(stream):
    """Wrap ``stream`` in a ``Mapping`` that applies ``destination``.

    The mapping is given the source names ``destination_latitude`` and
    ``destination_longitude``. ``stream.sources`` must contain
    ``'latitude'`` and ``'longitude'``.
    """
    lat_position = stream.sources.index('latitude')
    lon_position = stream.sources.index('longitude')
    extractor = destination(lat_position, lon_position)
    return Mapping(stream, extractor, ('destination_latitude', 'destination_longitude'))


class trip_filter(object):
    """Predicate that rejects examples whose trip id is in ``exclude``.

    ``id_trip_id`` is the position of the trip id inside the example passed
    to ``__call__``; ``exclude`` is any container supporting ``in``.
    """
    def __init__(self, id_trip_id, exclude):
        self.id_trip_id = id_trip_id
        self.exclude = exclude

    def __call__(self, data):
        """Return ``True`` to keep the example, ``False`` to drop it."""
        return data[self.id_trip_id] not in self.exclude
def filter_out_trips(exclude_trips, stream):
    """Wrap ``stream`` in a ``Filter`` that drops trips listed in ``exclude_trips``.

    ``stream.sources`` must contain ``'trip_id'``.
    """
    trip_id_position = stream.sources.index('trip_id')
    predicate = trip_filter(trip_id_position, exclude_trips)
    return Filter(stream, predicate)