diff options
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py index b6f5e14..57d7f14 100644 --- a/transformers.py +++ b/transformers.py @@ -32,6 +32,19 @@ def add_random_k(k, stream): stream = Mapping(stream, random_k, ('last_k',)) return stream +def add_last_k(k, stream): + id_polyline=stream.sources.index('polyline') + def last_k(x): + return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),) + stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + stream = Mapping(stream, last_k, ('last_k',)) + return stream + def add_destination(stream): id_polyline=stream.sources.index('polyline') return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',)) + +def concat_destination_xy(stream): + id_dx=stream.sources.index('destination_x') + id_dy=stream.sources.index('destination_y') + return Mapping(stream, lambda x: (numpy.array([x[id_dx], x[id_dy]], dtype=theano.config.floatX),), ('destination',)) |