aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/data/transformers.py b/data/transformers.py
index 57747fc..1b82dae 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -64,8 +64,9 @@ class TaxiGenerateSplits(Transformer):
dlat = numpy.float32(self.data[self.id_latitude][-1])
dlon = numpy.float32(self.data[self.id_longitude][-1])
+ ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
- return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)])
+ return tuple(r + [dlat, dlon, ttime])
class TaxiAddFirstLastLen(Transformer):
def __init__(self, k, stream):
@@ -87,7 +88,7 @@ class TaxiAddFirstLastLen(Transformer):
dtype=theano.config.floatX),
numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:],
dtype=theano.config.floatX))
- input_time = (15 * (len(data[self.id_latitude]) - 1),)
+ input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),)
return data + first_k + last_k + input_time
class TaxiAddDateTime(Transformer):
@@ -101,7 +102,9 @@ class TaxiAddDateTime(Transformer):
ts = data[self.id_timestamp]
date = datetime.datetime.utcfromtimestamp(ts)
yearweek = date.isocalendar()[1] - 1
- info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15)
+ info = (numpy.int8(51 if yearweek == 52 else yearweek),
+ numpy.int8(date.weekday()),
+ numpy.int8(date.hour * 4 + date.minute / 15))
return data + info
class TaxiExcludeTrips(Transformer):