aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
author Alex Auvolat <alex.auvolat@ens.fr> 2015-06-11 15:26:21 -0400
committer Alex Auvolat <alex.auvolat@ens.fr> 2015-06-11 15:26:21 -0400
commit c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4 (patch)
tree 6fb3fcc9425948a13fdc421534b7a00574c1e595 /data/transformers.py
parent 7e8dbac77ce712846954bdd5f4bfb62b6efaf7df (diff)
download taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.tar.gz
taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.zip
Add TaxiRemoveTestOnlyClients ; custom dumpmanager enabling multiprocessing
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- data/transformers.py 10
1 files changed, 10 insertions, 0 deletions
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
if not data[self.id_trip_id] in self.exclude: break
return data
+class TaxiRemoveTestOnlyClients(Transformer):
+    """Transformer that resets out-of-range ``origin_call`` values to zero.
+
+    Every example read from the wrapped stream has its ``origin_call``
+    entry compared with ``data.origin_call_train_size``; an entry at or
+    above that bound is replaced by ``numpy.int32(0)``.
+
+    NOTE(review): judging by the class name, such values presumably belong
+    to clients that only occur in the test set -- confirm against ``data``.
+    """
+
+    def __init__(self, stream):
+        super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+        # Position of the 'origin_call' source inside each example.
+        self.id_origin_call = stream.sources.index('origin_call')
+
+    def get_data(self, request=None):
+        """Return the next example as a tuple, with ``origin_call`` filtered.
+
+        Raises ``ValueError`` when ``request`` is anything other than None.
+        """
+        if request is not None:
+            raise ValueError
+        fields = list(next(self.child_epoch_iterator))
+        position = self.id_origin_call
+        if fields[position] >= data.origin_call_train_size:
+            fields[position] = numpy.int32(0)
+        return tuple(fields)