From 7dab7e47ce0e8c5ae996821794450a9ad3186cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Fri, 24 Jul 2015 16:09:48 -0400 Subject: Fix memory network --- data/transformers.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) (limited to 'data/transformers.py') diff --git a/data/transformers.py b/data/transformers.py index b3a8486..88fdcf6 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -6,7 +6,7 @@ import theano import fuel from fuel.schemes import ConstantScheme -from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack +from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources import data @@ -22,20 +22,7 @@ def at_least_k(k, v, pad_at_begin, is_longitude): v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1]))) return v - -class Select(Transformer): - produces_examples = True - - def __init__(self, data_stream, sources): - super(Select, self).__init__(data_stream) - self.ids = [data_stream.sources.index(source) for source in sources] - self.sources=sources - - def get_data(self, request=None): - if request is not None: - raise ValueError - data=next(self.child_epoch_iterator) - return [data[id] for id in self.ids] +Select = FilterSources class TaxiExcludeTrips(Transformer): produces_examples = True -- cgit v1.2.3