aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 21:20:32 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 21:20:32 -0400
commit13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (patch)
treeabc29e6a877a2f971b0be9715c112d8eee8b0eb4 /data/transformers.py
parent8d31f9240056ec110cf63bde79d7661321d8ca7a (diff)
downloadtaxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.tar.gz
taxi-13fc171f60ae1981c7ad4f2a302a8a85c29addc5.zip
Use new tvt dataset with option --tvt
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py17
1 files changed, 11 insertions, 6 deletions
diff --git a/data/transformers.py b/data/transformers.py
index c2eb97e..b3a8486 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -70,7 +70,9 @@ class TaxiGenerateSplits(Transformer):
def __init__(self, data_stream, max_splits=-1):
super(TaxiGenerateSplits, self).__init__(data_stream)
- self.sources = data_stream.sources + ('destination_latitude', 'destination_longitude', 'travel_time')
+ self.sources = data_stream.sources
+ if not data.tvt:
+ self.sources += ('destination_latitude', 'destination_longitude', 'travel_time')
self.max_splits = max_splits
self.data = None
self.splits = []
@@ -100,12 +102,15 @@ class TaxiGenerateSplits(Transformer):
r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
- dlat = numpy.float32(self.data[self.id_latitude][-1])
- dlon = numpy.float32(self.data[self.id_longitude][-1])
- ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
-
- return tuple(r + [dlat, dlon, ttime])
+ r = tuple(r)
+ if data.tvt:
+ return r
+ else:
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
+ ttime = numpy.int32(15 * (len(self.data[self.id_longitude]) - 1))
+ return r + (dlat, dlon, ttime)
class _taxi_add_first_last_len_helper(object):
def __init__(self, k, id_latitude, id_longitude):