diff options
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/transformers.py b/transformers.py index 29b8094..b6f5e14 100644 --- a/transformers.py +++ b/transformers.py @@ -1,6 +1,7 @@ from fuel.transformers import Transformer, Filter, Mapping import numpy import theano +import random class Select(Transformer): def __init__(self, data_stream, sources): @@ -14,13 +15,21 @@ class Select(Transformer): data=next(self.child_epoch_iterator) return [data[id] for id in self.ids] -def add_extremities(stream, k): +def add_first_k(k, stream): id_polyline=stream.sources.index('polyline') - def extremities(x): - return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(), - numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten()) + def first_k(x): + return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),) stream = Filter(stream, lambda x: len(x[id_polyline])>=k) - stream = Mapping(stream, extremities, ('first_k', 'last_k')) + stream = Mapping(stream, first_k, ('first_k',)) + return stream + +def add_random_k(k, stream): + id_polyline=stream.sources.index('polyline') + def random_k(x): + loc = random.randrange(len(x[id_polyline])-k+1) + return (numpy.array(x[id_polyline][loc:loc+k], dtype=theano.config.floatX).flatten(),) + stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + stream = Mapping(stream, random_k, ('last_k',)) return stream def add_destination(stream): |