aboutsummaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-24 16:09:48 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-24 16:09:48 -0400
commit7dab7e47ce0e8c5ae996821794450a9ad3186cd3 (patch)
treee0babcc305696a6e6a67a52acecd300bfdf22cf0 /data
parent60e6bc64d8e3c6679a6e2a960513c656d481f0ed (diff)
downloadtaxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.tar.gz
taxi-7dab7e47ce0e8c5ae996821794450a9ad3186cd3.zip
Fix memory network
Diffstat (limited to 'data')
-rw-r--r--data/transformers.py17
1 files changed, 2 insertions, 15 deletions
diff --git a/data/transformers.py b/data/transformers.py
index b3a8486..88fdcf6 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -6,7 +6,7 @@ import theano
import fuel
from fuel.schemes import ConstantScheme
-from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack
+from fuel.transformers import Batch, Mapping, SortMapping, Transformer, Unpack, FilterSources
import data
@@ -22,20 +22,7 @@ def at_least_k(k, v, pad_at_begin, is_longitude):
v = numpy.concatenate((v, numpy.full((k - len(v),), v[-1])))
return v
-
-class Select(Transformer):
- produces_examples = True
-
- def __init__(self, data_stream, sources):
- super(Select, self).__init__(data_stream)
- self.ids = [data_stream.sources.index(source) for source in sources]
- self.sources=sources
-
- def get_data(self, request=None):
- if request is not None:
- raise ValueError
- data=next(self.child_epoch_iterator)
- return [data[id] for id in self.ids]
+Select = FilterSources
class TaxiExcludeTrips(Transformer):
produces_examples = True