aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
author: Alex Auvolat <alex.auvolat@ens.fr>, 2015-07-27 15:02:27 -0400
committer: Alex Auvolat <alex.auvolat@ens.fr>, 2015-07-27 15:02:27 -0400
commit: ff1502ff1b6a4192974f73347b365a5d3a0e1f20 (patch)
tree: c8bed8360b4243f94db53d853834d363f2365197 /data/transformers.py
parent: 8754c6a34689a56ff5baab032e38105903024a3a (diff)
download: taxi-ff1502ff1b6a4192974f73347b365a5d3a0e1f20.tar.gz
download: taxi-ff1502ff1b6a4192974f73347b365a5d3a0e1f20.zip
Bidir RNN with window
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- data/transformers.py | 33
1 files changed, 32 insertions, 1 deletions
diff --git a/data/transformers.py b/data/transformers.py
index 88fdcf6..f0ed44a 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -142,7 +142,7 @@ class _balanced_batch_helper(object):
def __init__(self, key):
self.key = key
def __call__(self, data):
- return len(data[self.key])
+ return data[self.key].shape[0]
def balanced_batch(stream, key, batch_size, batch_sort_size):
stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
@@ -176,3 +176,34 @@ class _add_destination_helper(object):
def add_destination(stream):
fun = _add_destination_helper(stream.sources.index('latitude'), stream.sources.index('longitude'))
return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude'))
+
+class _window_helper(object):
+ def __init__(self, latitude, longitude, window_len):
+ self.latitude = latitude
+ self.longitude = longitude
+ self.window_len = window_len
+ def makewindow(self, x):
+ assert len(x.shape) == 1
+
+ if x.shape[0] < self.window_len:
+ x = numpy.concatenate(
+ [x, numpy.full((self.window_len - x.shape[0],), x[-1])])
+
+ y = [x[i: i+x.shape[0]-self.window_len+1][:, None]
+ for i in range(self.window_len)]
+
+ return numpy.concatenate(y, axis=1)
+
+ def __call__(self, data):
+ data = list(data)
+ data[self.latitude] = self.makewindow(data[self.latitude])
+ data[self.longitude] = self.makewindow(data[self.longitude])
+ return tuple(data)
+
+
+def window(stream, window_len):
+ fun = _window_helper(stream.sources.index('latitude'),
+ stream.sources.index('longitude'),
+ window_len)
+ return Mapping(stream, fun)
+