# Excerpt of data/transformers.py as left by commit
# ff1502ff1b6a4192974f73347b365a5d3a0e1f20 ("Bidir RNN with window",
# Alex Auvolat, Mon, 27 Jul 2015 15:02:27 -0400), recovered from a
# cgit v1.2.3 patch view whose line breaks had been lost.
#
# The commit touched one file (32 insertions, 1 deletion):
#   * _balanced_batch_helper.__call__ now returns data[self.key].shape[0]
#     (it previously returned len(data[self.key]));
#   * _window_helper and window() were appended after add_destination().
#
# numpy, Batch, ConstantScheme and Mapping are used below but are not
# imported in this excerpt; presumably they are imported at the top of
# data/transformers.py, outside the lines shown by the patch.


class _balanced_batch_helper(object):
    """Sort key returning the first-axis length of one field of an example."""

    def __init__(self, key):
        # Index of the field (within an example tuple) whose length is used.
        self.key = key

    def __call__(self, data):
        # Uses .shape[0] rather than len(); the commit made this change, so
        # the field is expected to expose a ``shape`` attribute.
        return data[self.key].shape[0]


# Unchanged context that the patch shows only in part (truncated by the
# hunk boundaries, so it is not reproduced as code here):
#
#     def balanced_batch(stream, key, batch_size, batch_sort_size):
#         stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size))
#         ...
#
#     class _add_destination_helper(object):
#         ...


def add_destination(stream):
    """Wrap *stream* so each example also carries destination coordinates.

    The positions of the 'latitude' and 'longitude' sources are looked up in
    ``stream.sources`` and handed to ``_add_destination_helper`` (defined
    elsewhere in the file). The resulting Mapping declares two extra
    sources, 'destination_latitude' and 'destination_longitude'.
    """
    fun = _add_destination_helper(stream.sources.index('latitude'),
                                  stream.sources.index('longitude'))
    return Mapping(stream, fun,
                   add_sources=('destination_latitude',
                                'destination_longitude'))


class _window_helper(object):
    """Replace the latitude and longitude fields by sliding windows.

    ``latitude`` and ``longitude`` are the positions of the two fields inside
    an example tuple; ``window_len`` is the number of consecutive points in
    each window.
    """

    def __init__(self, latitude, longitude, window_len):
        self.latitude = latitude
        self.longitude = longitude
        self.window_len = window_len

    def makewindow(self, x):
        """Return every length-``window_len`` window of the 1-D series *x*.

        The result has shape ``(n_windows, window_len)`` and row ``j`` holds
        ``x[j:j + window_len]``.

        * A series shorter than ``window_len`` is padded on the right with
          its last value, which yields exactly one window.
        * An empty series has no last value to pad with; it yields an empty
          ``(0, window_len)`` array instead of raising ``IndexError``.

        Raises:
            ValueError: if *x* is not one-dimensional or ``window_len`` is
                smaller than 1.
        """
        x = numpy.asarray(x)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if x.ndim != 1:
            raise ValueError(
                'makewindow expects a 1-D array, got shape %r' % (x.shape,))

        width = self.window_len
        if width < 1:
            raise ValueError(
                'window_len must be at least 1, got %r' % (width,))

        length = x.shape[0]
        if length == 0:
            return numpy.empty((0, width), dtype=x.dtype)

        if length < width:
            # Pin the dtype: older NumPy releases made numpy.full() default
            # to float64, which silently upcast short float32 series.
            tail = numpy.full((width - length,), x[-1], dtype=x.dtype)
            x = numpy.concatenate([x, tail])
            length = width

        # Index grid: entry (j, i) is j + i, so row j selects x[j:j + width].
        starts = numpy.arange(length - width + 1)[:, None]
        offsets = numpy.arange(width)[None, :]
        return x[starts + offsets]

    def __call__(self, data):
        """Return *data* as a tuple with both coordinate fields windowed."""
        fields = list(data)
        for position in (self.latitude, self.longitude):
            fields[position] = self.makewindow(fields[position])
        return tuple(fields)


def window(stream, window_len):
    """Wrap *stream* so 'latitude' and 'longitude' become sliding windows.

    The positions of the two sources are read from ``stream.sources`` and
    every example is passed through ``_window_helper``. No ``add_sources``
    argument is given to Mapping, so no new source names are declared.
    """
    helper = _window_helper(stream.sources.index('latitude'),
                            stream.sources.index('longitude'),
                            window_len)
    return Mapping(stream, helper)