diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/data/transformers.py b/data/transformers.py index 57747fc..1b82dae 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -64,8 +64,9 @@ class TaxiGenerateSplits(Transformer): dlat = numpy.float32(self.data[self.id_latitude][-1]) dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) + return tuple(r + [dlat, dlon, ttime]) class TaxiAddFirstLastLen(Transformer): def __init__(self, k, stream): @@ -87,7 +88,7 @@ class TaxiAddFirstLastLen(Transformer): dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) - input_time = (15 * (len(data[self.id_latitude]) - 1),) + input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),) return data + first_k + last_k + input_time class TaxiAddDateTime(Transformer): @@ -101,7 +102,9 @@ class TaxiAddDateTime(Transformer): ts = data[self.id_timestamp] date = datetime.datetime.utcfromtimestamp(ts) yearweek = date.isocalendar()[1] - 1 - info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15) + info = (numpy.int8(51 if yearweek == 52 else yearweek), + numpy.int8(date.weekday()), + numpy.int8(date.hour * 4 + date.minute / 15)) return data + info class TaxiExcludeTrips(Transformer): |