diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-18 16:22:00 -0400 |
commit | 6d946f29f7548c75e97f30c4356dbac200ee6cce (patch) | |
tree | 387e586c7ad0c1a0167d21451c9a8c877cf3ef0e /data | |
parent | 1e6d08b0c9ac5983691b182631c71e9d46ee71cc (diff) | |
download | taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.tar.gz taxi-6d946f29f7548c75e97f30c4356dbac200ee6cce.zip |
Refactor models, clean the code and separate training from testing.
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/data/transformers.py b/data/transformers.py index 57747fc..1b82dae 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -64,8 +64,9 @@ class TaxiGenerateSplits(Transformer): dlat = numpy.float32(self.data[self.id_latitude][-1]) dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - return tuple(r + [dlat, dlon, 15 * (len(self.data[self.id_longitude]) - 1)]) + return tuple(r + [dlat, dlon, ttime]) class TaxiAddFirstLastLen(Transformer): def __init__(self, k, stream): @@ -87,7 +88,7 @@ class TaxiAddFirstLastLen(Transformer): dtype=theano.config.floatX), numpy.array(at_least_k(self.k, data[self.id_longitude], True, True)[-self.k:], dtype=theano.config.floatX)) - input_time = (15 * (len(data[self.id_latitude]) - 1),) + input_time = (numpy.int32(15 * (len(data[self.id_latitude]) - 1)),) return data + first_k + last_k + input_time class TaxiAddDateTime(Transformer): @@ -101,7 +102,9 @@ class TaxiAddDateTime(Transformer): ts = data[self.id_timestamp] date = datetime.datetime.utcfromtimestamp(ts) yearweek = date.isocalendar()[1] - 1 - info = ((51 if yearweek == 52 else yearweek), date.weekday(), date.hour * 4 + date.minute / 15) + info = (numpy.int8(51 if yearweek == 52 else yearweek), + numpy.int8(date.weekday()), + numpy.int8(date.hour * 4 + date.minute / 15)) return data + info class TaxiExcludeTrips(Transformer): |