diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:20:50 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 13:20:50 -0400 |
commit | 929eaf8dd0233f8423b24b93b78c99fc9df65343 (patch) | |
tree | 9cd2982aa8ca60505ddb35c93f427fe63b0c8508 /transformers.py | |
parent | 9adfe767010e23823089b4db94cb4dc53cc3c12a (diff) | |
download | taxi-929eaf8dd0233f8423b24b93b78c99fc9df65343.tar.gz taxi-929eaf8dd0233f8423b24b93b78c99fc9df65343.zip |
Fixes
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/transformers.py b/transformers.py index 79e8327..876cee2 100644 --- a/transformers.py +++ b/transformers.py @@ -43,8 +43,8 @@ class TaxiGenerateSplits(Transformer): raise ValueError while self.isplit >= len(self.splits): self.data = next(self.child_epoch_iterator) - self.splits = range(len(self.data[self.id_polyline])) - random.shuffle_array(self.splits) + self.splits = range(len(self.data[self.id_longitude])) + random.shuffle(self.splits) if self.max_splits != -1 and len(self.splits) > self.max_splits: self.splits = self.splits[:self.max_splits] self.isplit = 0 @@ -55,11 +55,11 @@ class TaxiGenerateSplits(Transformer): r = list(self.data) - r[self.id_latitude] = r[self.id_latitude][:n] - r[self.id_longitude] = r[self.id_longitude][:n] + r[self.id_latitude] = numpy.array(r[self.id_latitude][:n], dtype=theano.config.floatX) + r[self.id_longitude] = numpy.array(r[self.id_longitude][:n], dtype=theano.config.floatX) - dlat = self.data[self.id_latitude][-1] - dlon = self.data[self.id_longitude][-1] + dlat = numpy.float32(self.data[self.id_latitude][-1]) + dlon = numpy.float32(self.data[self.id_longitude][-1]) return tuple(r + [dlat, dlon]) |