aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py
index b6f5e14..57d7f14 100644
--- a/transformers.py
+++ b/transformers.py
@@ -32,6 +32,19 @@ def add_random_k(k, stream):
stream = Mapping(stream, random_k, ('last_k',))
return stream
+def add_last_k(k, stream):
+ id_polyline=stream.sources.index('polyline')
+ def last_k(x):
+ return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),)
+ stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ stream = Mapping(stream, last_k, ('last_k',))
+ return stream
+
def add_destination(stream):
id_polyline=stream.sources.index('polyline')
return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',))
+
+def concat_destination_xy(stream):
+ id_dx=stream.sources.index('destination_x')
+ id_dy=stream.sources.index('destination_y')
+ return Mapping(stream, lambda x: (numpy.array([x[id_dx], x[id_dy]], dtype=theano.config.floatX),), ('destination',))