From 6d946f29f7548c75e97f30c4356dbac200ee6cce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Mon, 18 May 2015 16:22:00 -0400 Subject: Refactor models, clean the code and separate training from testing. --- data/transformers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'data') diff --git a/data/transformers.py b/data/transformers.py index 57747fc..1b82dae 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -64,8 +64,9 @@ class TaxiGenerateSplits(Transformer): dlat = numpy.float32(self.data[self.id_latitude][-1]) dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) + return tuple(r + [dlat, dlon, ttime]) class TaxiAddFirstLastLen(Transformer): def __init__(self, k, stream): @@ -87,7 +88,7 @@ class TaxiAddFirstLastLen(Transformer): dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) - input_time = (15 * (len(data[self.id_latitude]) - 1),) + input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),) return data + first_k + last_k + input_time class TaxiAddDateTime(Transformer): @@ -101,7 +102,9 @@ class TaxiAddDateTime(Transformer): ts = data[self.id_timestamp] date = datetime.datetime.utcfromtimestamp(ts) yearweek = date.isocalendar()[1] - 1 - info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15) + info = (numpy.int8(51 if yearweek == 52 else yearweek), + numpy.int8(date.weekday()), + numpy.int8(date.hour * 4 + date.minute / 15)) return data + info class TaxiExcludeTrips(Transformer): -- cgit v1.2.3