diff options
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/transformers.py b/transformers.py index 92ee439..29b8094 100644 --- a/transformers.py +++ b/transformers.py @@ -1,4 +1,6 @@ from fuel.transformers import Transformer, Filter, Mapping +import numpy +import theano class Select(Transformer): def __init__(self, data_stream, sources): @@ -15,11 +17,12 @@ class Select(Transformer): def add_extremities(stream, k): id_polyline=stream.sources.index('polyline') def extremities(x): - return (x[id_polyline][:k], x[id_polyline][-k:]) + return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(), + numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten()) stream = Filter(stream, lambda x: len(x[id_polyline])>=k) stream = Mapping(stream, extremities, ('first_k', 'last_k')) return stream def add_destination(stream): id_polyline=stream.sources.index('polyline') - return Mapping(stream, lambda x: x[id_polyline][-1], ('destination',)) + return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',)) |