diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 21:20:32 -0400 |
commit | 13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (patch) | |
tree | abc29e6a877a2f971b0be9715c112d8eee8b0eb4 /data/transformers.py | |
parent | 8d31f9240056ec110cf63bde79d7661321d8ca7a (diff) | |
download | taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.tar.gz taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.zip |
Use new tvt dataset with option --tvt
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- | data/transformers.py | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/data/transformers.py b/data/transformers.py index c2eb97e..b3a8486 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -70,7 +70,9 @@ class TaxiGenerateSplits(Transformer): def __init__(self, data_stream, max_splits=-1): super(TaxiGenerateSplits, self).__init__(data_stream) - self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'travel_time') + self.sources = data_stream.sources + if not data.tvt: + self.sources += ('destination_latitude', 'destination_longitude', 'travel_time') self.max_splits = max_splits self.data = None self.splits = [] @@ -100,12 +102,15 @@ class TaxiGenerateSplits(Transformer): r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX) r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX) - dlat = numpy.float32(self.data[self.id_latitude][-1]) - dlon = numpy.float32(self.data[self.id_longitude][-1]) - ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) - - return tuple(r + [dlat, dlon, ttime]) + r = tuple(r) + if data.tvt: + return r + else: + dlat = numpy.float32(self.data[self.id_latitude][-1]) + dlon = numpy.float32(self.data[self.id_longitude][-1]) + ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1)) + return r + (dlat, dlon, ttime) class _taxi_add_first_last_len_helper(object): def __init__(self, k, id_latitude, id_longitude): |