aboutsummaryrefslogtreecommitdiff
path: root/transformers.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 17:27:43 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-04-27 17:27:43 -0400
commit9a60f6c4e39c09187710608a9e225b6024b34364 (patch)
tree92e43401b6c6d3982081a35ec680b82856ec00c0 /transformers.py
parent107b3798cca35472e158d94f36a0bd08f3fe1fe8 (diff)
downloadtaxi-9a60f6c4e39c09187710608a9e225b6024b34364.tar.gz
taxi-9a60f6c4e39c09187710608a9e225b6024b34364.zip
Add validation set ; fix lat/lon
Diffstat (limited to 'transformers.py')
-rw-r--r--transformers.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/transformers.py b/transformers.py
index b6f5e14..57d7f14 100644
--- a/transformers.py
+++ b/transformers.py
@@ -32,6 +32,19 @@ def add_random_k(k, stream):
stream = Mapping(stream, random_k, ('last_k',))
return stream
+def add_last_k(k, stream):
+ id_polyline=stream.sources.index('polyline')
+ def last_k(x):
+ return (numpy.array(x[id_polyline][-k:], dtype=theano.config.floatX).flatten(),)
+ stream = Filter(stream, lambda x: len(x[id_polyline])>=k)
+ stream = Mapping(stream, last_k, ('last_k',))
+ return stream
+
def add_destination(stream):
id_polyline=stream.sources.index('polyline')
return Mapping(stream, lambda x: (numpy.array(x[id_polyline][-1], dtype=theano.config.floatX),), ('destination',))
+
+def concat_destination_xy(stream):
+ id_dx=stream.sources.index('destination_x')
+ id_dy=stream.sources.index('destination_y')
+ return Mapping(stream, lambda x: (numpy.array([x[id_dx], x[id_dy]], dtype=theano.config.floatX),), ('destination',))