diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 13:08:56 -0400 |
commit | a25d4fb6e92f203183de2d89e8c467a6b14e1730 (patch) | |
tree | 6d448760e647572d52242f5726224cdf20e832ee /transformers.py | |
parent | ccd1245db7f6799ab4e1f45a8cead85ed67f1c72 (diff) | |
download | taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.tar.gz taxi-a25d4fb6e92f203183de2d89e8c467a6b14e1730.zip |
Implement HDist, transformer that selects k at a random position.
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/transformers.py b/transformers.py index 29b8094..b6f5e14 100644 --- a/transformers.py +++ b/transformers.py @@ -1,6 +1,7 @@ from fuel.transformers import Transformer, Filter, Mapping import numpy import theano +import random class Select(Transformer): def __init__(self, data_stream, sources): @@ -14,13 +15,21 @@ class Select(Transformer): data=next(self.child_epoch_iterator) return [data[id] for id in self.ids] -def add_extremities(stream, k): +def add_first_k(k, stream): id_polyline=stream.sources.index('polyline') - def extremities(x): - return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(), - numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten()) + def first_k(x): + return (numpy.array(x[id_polyline][:k], dtype=theano.config.floatX).flatten(),) stream = Filter(stream, lambda x: len(x[id_polyline])>=k) - stream = Mapping(stream, extremities, ('first_k', 'last_k')) + stream = Mapping(stream, first_k, ('first_k',)) + return stream + +def add_random_k(k, stream): + id_polyline=stream.sources.index('polyline') + def random_k(x): + loc = random.randrange(len(x[id_polyline])-k+1) + return (numpy.array(x[id_polyline][loc:loc+k], dtype=theano.config.floatX).flatten(),) + stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + stream = Mapping(stream, random_k, ('last_k',)) return stream def add_destination(stream): |