diff options
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/transformers.py b/transformers.py index 57d7f14..c60d362 100644 --- a/transformers.py +++ b/transformers.py @@ -3,6 +3,17 @@ import numpy import theano import random +def at_least_k(k, pl, pad_at_begin): + if len(pl) == 0: + pl = [[ -8.61612, 41.1573]] + if len(pl) < k: + if pad_at_begin: + pl = [pl[0]] * (k - len(pl)) + pl + else: + pl = pl + [pl[-1]] * (k - len(pl)) + return pl + + class Select(Transformer): def __init__(self, data_stream, sources): super(Select, self).__init__(data_stream) @@ -18,31 +29,34 @@ class Select(Transformer): def add_first_k(k, stream): id_polyline=stream.sources.index('polyline') def first_k(x): - return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),) - stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + pl = at_least_k(k, x[id_polyline], False) + return (numpy.array(pl[:k], dtype=theano.config.floatX).flatten(),) stream = Mapping(stream, first_k, ('first_k',)) return stream def add_random_k(k, stream): id_polyline=stream.sources.index('polyline') def random_k(x): - loc = random.randrange(len(x[id_polyline])-k+1) - return (numpy.array(x[id_polyline][loc:loc+k], dtype=theano.config.floatX).flatten(),) - stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + pl = at_least_k(k, x[id_polyline], True) + loc = random.randrange(len(pl)-k+1) + return (numpy.array(pl[loc:loc+k], dtype=theano.config.floatX).flatten(),) stream = Mapping(stream, random_k, ('last_k',)) return stream def add_last_k(k, stream): id_polyline=stream.sources.index('polyline') def last_k(x): - return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),) - stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + pl = at_least_k(k, x[id_polyline], True) + return (numpy.array(pl[-k:], dtype=theano.config.floatX).flatten(),) stream = Mapping(stream, last_k, ('last_k',)) return stream def add_destination(stream): id_polyline=stream.sources.index('polyline') - return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',)) + return Mapping(stream, + lambda x: + (numpy.array(at_least_k(1, x[id_polyline], True)[-1], dtype=theano.config.floatX),), + ('destination',)) def concat_destination_xy(stream): id_dx=stream.sources.index('destination_x') |