diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 16:59:30 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-29 16:59:30 -0400 |
commit | b07bf7612b07a08bd1298b87347889a71d390012 (patch) | |
tree | 6f6dafafa653e1163676444ae2009e6bdd89e0e7 /transformers.py | |
parent | 7604b28ff6e8293af383ae7328ea2285b3c9bba5 (diff) | |
download | taxi-b07bf7612b07a08bd1298b87347889a71d390012.tar.gz taxi-b07bf7612b07a08bd1298b87347889a71d390012.zip |
Fixes.
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/transformers.py b/transformers.py index 3473fd5..c5b8d87 100644 --- a/transformers.py +++ b/transformers.py @@ -2,10 +2,11 @@ from fuel.transformers import Transformer, Filter, Mapping import numpy import theano import random +import data def at_least_k(k, v, pad_at_begin, is_longitude): if len(v) == 0: - v = numpy.array([41.1573 if is_longitude else -8.61612], dtype=theano.config.floatX) + v = numpy.array([data.porto_center[1 if is_longitude else 0]], dtype=theano.config.floatX) if len(v) < k: if pad_at_begin: v = numpy.concatenate((numpy.full((k - len(v),), v[0]), v)) @@ -40,7 +41,7 @@ def add_random_k(k, stream): id_longitude = stream.sources.index('longitude') def random_k(x): lat = at_least_k(k, x[id_latitude], True, False) - lon = at_least_k(k, x[id_latitude], True, True) + lon = at_least_k(k, x[id_longitude], True, True) loc = random.randrange(len(lat)-k+1) return (numpy.array(lat[loc:loc+k], dtype=theano.config.floatX), numpy.array(lon[loc:loc+k], dtype=theano.config.floatX)) |