aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 13:20:50 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 13:20:50 -0400
commit929eaf8dd0233f8423b24b93b78c99fc9df65343 (patch)
tree9cd2982aa8ca60505ddb35c93f427fe63b0c8508 /transformers.py
parent9adfe767010e23823089b4db94cb4dc53cc3c12a (diff)
downloadtaxi-929eaf8dd0233f8423b24b93b78c99fc9df65343.tar.gz
taxi-929eaf8dd0233f8423b24b93b78c99fc9df65343.zip
Fixes
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/transformers.py b/transformers.py
index 79e8327..876cee2 100644
--- a/transformers.py
+++ b/transformers.py
@@ -43,8 +43,8 @@ class TaxiGenerateSplits(Transformer):
raise ValueError
while self.isplit >= len(self.splits):
self.data = next(self.child_epoch_iterator)
- self.splits = range(len(self.data[self.id_polyline]))
- random.shuffle_array(self.splits)
+ self.splits = range(len(self.data[self.id_longitude]))
+ random.shuffle(self.splits)
if self.max_splits != -1 and len(self.splits) > self.max_splits:
self.splits = self.splits[:self.max_splits]
self.isplit = 0
@@ -55,11 +55,11 @@ class TaxiGenerateSplits(Transformer):
r = list(self.data)
- r[self.id_latitude] = r[self.id_latitude][:n]
- r[self.id_longitude] = r[self.id_longitude][:n]
+ r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX)
+ r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX)
- dlat = self.data[self.id_latitude][-1]
- dlon = self.data[self.id_longitude][-1]
+ dlat = numpy.float32(self.data[self.id_latitude][-1])
+ dlon = numpy.float32(self.data[self.id_longitude][-1])
return tuple(r + [dlat, dlon])