aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py30
1 files changed, 22 insertions, 8 deletions
diff --git a/transformers.py b/transformers.py
index 57d7f14..c60d362 100644
--- a/transformers.py
+++ b/transformers.py
@@ -3,6 +3,17 @@ import numpy
import theano
import random
+def at_least_k(k, pl, pad_at_begin):
+ if len(pl) == 0:
+ pl = [[ -8.61612, 41.1573]]
+ if len(pl) < k:
+ if pad_at_begin:
+ pl = [pl[0]] * (k - len(pl)) + pl
+ else:
+ pl = pl + [pl[-1]] * (k - len(pl))
+ return pl
+
+
class Select(Transformer):
def __init__(self, data_stream, sources):
super(Select, self).__init__(data_stream)
@@ -18,31 +29,34 @@ class Select(Transformer):
def add_first_k(k, stream):
id_polyline=stream.sources.index('polyline')
def first_k(x):
- return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),)
- stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ pl = at_least_k(k, x[id_polyline], False)
+ return (numpy.array(pl[:k], dtype=theano.config.floatX).flatten(),)
stream = Mapping(stream, first_k, ('first_k',))
return stream
def add_random_k(k, stream):
id_polyline=stream.sources.index('polyline')
def random_k(x):
- loc = random.randrange(len(x[id_polyline])-k+1)
- return (numpy.array(x[id_polyline][loc:loc+k], dtype=theano.config.floatX).flatten(),)
- stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ pl = at_least_k(k, x[id_polyline], True)
+ loc = random.randrange(len(pl)-k+1)
+ return (numpy.array(pl[loc:loc+k], dtype=theano.config.floatX).flatten(),)
stream = Mapping(stream, random_k, ('last_k',))
return stream
def add_last_k(k, stream):
id_polyline=stream.sources.index('polyline')
def last_k(x):
- return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),)
- stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ pl = at_least_k(k, x[id_polyline], True)
+ return (numpy.array(pl[-k:], dtype=theano.config.floatX).flatten(),)
stream = Mapping(stream, last_k, ('last_k',))
return stream
def add_destination(stream):
id_polyline=stream.sources.index('polyline')
- return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',))
+ return Mapping(stream,
+ lambda x:
+ (numpy.array(at_least_k(1, x[id_polyline], True)[-1], dtype=theano.config.floatX),),
+ ('destination',))
def concat_destination_xy(stream):
id_dx=stream.sources.index('destination_x')