aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/data/transformers.py b/data/transformers.py
index c2eb97e..b3a8486 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -70,7 +70,9 @@ class TaxiGenerateSplits(Transformer):
def __init__(self, data_stream, max_splits=-1):
super(TaxiGenerateSplits, self).__init__(data_stream)
- self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'travel_time')
+ self.sources = data_stream.sources
+ if not data.tvt:
+ self.sources += ('destination_latitude', 'destination_longitude', 'travel_time')
self.max_splits = max_splits
self.data = None
self.splits = []
@@ -100,12 +102,15 @@ class TaxiGenerateSplits(Transformer):
r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
- dlat = numpy.float32(self.data[self.id_latitude][-1])
- dlon = numpy.float32(self.data[self.id_longitude][-1])
- ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
-
- return tuple(r + [dlat, dlon, ttime])
+ r = tuple(r)
+ if data.tvt:
+ return r
+ else:
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
+ ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
+ return r + (dlat, dlon, ttime)
class _taxi_add_first_last_len_helper(object):
def __init__(self, k, id_latitude, id_longitude):