diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-24 16:09:48 -0400 |
commit | 7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch) | |
tree | e0babcc305696a6e6a67a52acecd300bfdf22cf0 /data | |
parent | 60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff) | |
download | taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip |
Fix memory network
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 17 |
1 files changed, 2 insertions, 15 deletions
diff --git a/data/transformers.py b/data/transformers.py index b3a8486..88fdcf6 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -6,7 +6,7 @@ import theano import fuel from fuel.schemes import ConstantScheme -from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack +from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources import data @@ -22,20 +22,7 @@ def at_least_k(k, v, pad_at_begin, is_longitude): v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1]))) return v - -class Select(Transformer): - produces_examples = True - - def __init__(self, data_stream, sources): - super(Select, self).__init__(data_stream) - self.ids = [data_stream.sources.index(source) for source in sources] - self.sources=sources - - def get_data(self, request=None): - if request is not None: - raise ValueError - data=next(self.child_epoch_iterator) - return [data[id] for id in self.ids] +Select = FilterSources class TaxiExcludeTrips(Transformer): produces_examples = True |