diff options
Diffstat (limited to 'data')
-rw-r--r-- | data/transformers.py | 17 |
1 files changed, 2 insertions, 15 deletions
diff --git a/data/transformers.py b/data/transformers.py index b3a8486..88fdcf6 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -6,7 +6,7 @@ import theano import fuel from fuel.schemes import ConstantScheme -from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack +from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources import data @@ -22,20 +22,7 @@ def at_least_k(k, v, pad_at_begin, is_longitude): v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1]))) return v - -class Select(Transformer): - produces_examples = True - - def __init__(self, data_stream, sources): - super(Select, self).__init__(data_stream) - self.ids = [data_stream.sources.index(source) for source in sources] - self.sources=sources - - def get_data(self, request=None): - if request is not None: - raise ValueError - data=next(self.child_epoch_iterator) - return [data[id] for id in self.ids] +Select = FilterSources class TaxiExcludeTrips(Transformer): produces_examples = True |