diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-27 15:02:27 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-27 15:02:27 -0400 |
commit | ff1502ff1b6a4192974f73347b365a5d3a0e1f20 (patch) | |
tree | c8bed8360b4243f94db53d853834d363f2365197 /data/transformers.py | |
parent | 8754c6a34689a56ff5baab032e38105903024a3a (diff) | |
download | taxi-ff1502ff1b6a4192974f73347b365a5d3a0e1f20.tar.gz taxi-ff1502ff1b6a4192974f73347b365a5d3a0e1f20.zip |
Bidir RNN with window
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- | data/transformers.py | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/data/transformers.py b/data/transformers.py index 88fdcf6..f0ed44a 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -142,7 +142,7 @@ class _balanced_batch_helper(object): def __init__(self, key): self.key = key def __call__(self, data): - return len(data[self.key]) + return data[self.key].shape[0] def balanced_batch(stream, key, batch_size, batch_sort_size): stream = Batch(stream, iteration_scheme=ConstantScheme(batch_size * batch_sort_size)) @@ -176,3 +176,34 @@ class _add_destination_helper(object): def add_destination(stream): fun = _add_destination_helper(stream.sources.index('latitude'), stream.sources.index('longitude')) return Mapping(stream, fun, add_sources=('destination_latitude', 'destination_longitude')) + +class _window_helper(object): + def __init__(self, latitude, longitude, window_len): + self.latitude = latitude + self.longitude = longitude + self.window_len = window_len + def makewindow(self, x): + assert len(x.shape) == 1 + + if x.shape[0] < self.window_len: + x = numpy.concatenate( + [x, numpy.full((self.window_len - x.shape[0],), x[-1])]) + + y = [x[i: i+x.shape[0]-self.window_len+1][:, None] + for i in range(self.window_len)] + + return numpy.concatenate(y, axis=1) + + def __call__(self, data): + data = list(data) + data[self.latitude] = self.makewindow(data[self.latitude]) + data[self.longitude] = self.makewindow(data[self.longitude]) + return tuple(data) + + +def window(stream, window_len): + fun = _window_helper(stream.sources.index('latitude'), + stream.sources.index('longitude'), + window_len) + return Mapping(stream, fun) + |