aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/transformers.py b/transformers.py
index 79e8327..876cee2 100644
--- a/transformers.py
+++ b/transformers.py
@@ -43,8 +43,8 @@ class TaxiGenerateSplits(Transformer):
raise ValueError
while self.isplit >= len(self.splits):
self.data = next(self.child_epoch_iterator)
- self.splits = range(len(self.data[self.id_polyline]))
- random.shuffle_array(self.splits)
+ self.splits = range(len(self.data[self.id_longitude]))
+ random.shuffle(self.splits)
if self.max_splits != -1 and len(self.splits) > self.max_splits:
self.splits = self.splits[:self.max_splits]
self.isplit = 0
@@ -55,11 +55,11 @@ class TaxiGenerateSplits(Transformer):
r = list(self.data)
- r[self.id_latitude] = r[self.id_latitude][:n]
- r[self.id_longitude] = r[self.id_longitude][:n]
+ r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
+ r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
- dlat = self.data[self.id_latitude][-1]
- dlon = self.data[self.id_longitude][-1]
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
return tuple(r + [dlat, dlon])