aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-08 13:59:44 -0400
commit0ecac7973fd02f44af9c8bc5765f7c159c94b23a (patch)
treeec28ec6253d84b3c55f8fa88adce4e720d26bd07 /data/transformers.py
parent1ffd1fc355f6fddcb6cd3d93c0df58513d064472 (diff)
downloadtaxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.tar.gz
taxi-0ecac7973fd02f44af9c8bc5765f7c159c94b23a.zip
Fusion AddFirstK and AddLastK transformers, and add 'input_time' field.
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py23
1 files changed, 7 insertions, 16 deletions
diff --git a/data/transformers.py b/data/transformers.py
index e3ff7b1..57747fc 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -67,10 +67,12 @@ class TaxiGenerateSplits(Transformer):
return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)])
-class TaxiAddFirstK(Transformer):
+class TaxiAddFirstLastLen(Transformer):
def __init__(self, k, stream):
- super(TaxiAddFirstK, self).__init__(stream)
- self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude')
+ super(TaxiAddFirstLastLen, self).__init__(stream)
+ self.sources = stream.sources + ('first_k_latitude', 'first_k_longitude',
+ 'last_k_latitude', 'last_k_longitude',
+ 'input_time')
self.id_latitude = stream.sources.index('latitude')
self.id_longitude = stream.sources.index('longitude')
self.k = k
@@ -81,23 +83,12 @@ class TaxiAddFirstK(Transformer):
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], False, True)[:self.k],
dtype=theano.config.floatX))
- return data + first_k
-
-class TaxiAddLastK(Transformer):
- def __init__(self, k, stream):
- super(TaxiAddLastK, self).__init__(stream)
- self.sources = stream.sources + ('last_k_latitude', 'last_k_longitude')
- self.id_latitude = stream.sources.index('latitude')
- self.id_longitude = stream.sources.index('longitude')
- self.k = k
- def get_data(self, request=None):
- if request is not None: raise ValueError
- data = next(self.child_epoch_iterator)
last_k = (numpy.array(at_least_k(self.k, data[self.id_latitude], True, False)[-self.k:],
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
dtype=theano.config.floatX))
- return data + last_k
+ input_time = (15 * (len(data[self.id_latitude]) - 1),)
+ return data + first_k + last_k + input_time
class TaxiAddDateTime(Transformer):
def __init__(self, stream):