diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-27 17:27:43 -0400 |
commit | 9a60f6c4e39c09187710608a9e225b6024b34364 (patch) | |
tree | 92e43401b6c6d3982081a35ec680b82856ec00c0 /transformers.py | |
parent | 107b3798cca35472e158d94f36a0bd08f3fe1fe8 (diff) | |
download | taxi-9a60f6c4e39c09187710608a9e225b6024b34364.tar.gz taxi-9a60f6c4e39c09187710608a9e225b6024b34364.zip |
Add validation set ; fix lat/lon
Diffstat (limited to 'transformers.py')
-rw-r--r-- | transformers.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py index b6f5e14..57d7f14 100644 --- a/transformers.py +++ b/transformers.py @@ -32,6 +32,19 @@ def add_random_k(k, stream): stream = Mapping(stream, random_k, ('last_k',)) return stream +def add_last_k(k, stream): + id_polyline=stream.sources.index('polyline') + def last_k(x): + return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),) + stream = Filter(stream, lambda x: len(x[id_polyline])>=k) + stream = Mapping(stream, last_k, ('last_k',)) + return stream + def add_destination(stream): id_polyline=stream.sources.index('polyline') return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',)) + +def concat_destination_xy(stream): + id_dx=stream.sources.index('destination_x') + id_dy=stream.sources.index('destination_y') + return Mapping(stream, lambda x: (numpy.array([x[id_dx], x[id_dy]], dtype=theano.config.floatX),), ('destination',)) |