aboutsummaryrefslogtreecommitdiff
path: root/data/transformers.py
diff options
context:
space:
mode:
Diffstat (limited to 'data/transformers.py')
-rw-r--r--data/transformers.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
if not data[self.id_trip_id] in self.exclude: break
return data
+class TaxiRemoveTestOnlyClients(Transformer):
+ def __init__(self, stream):
+ super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+ self.id_origin_call = stream.sources.index('origin_call')
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ x = list(next(self.child_epoch_iterator))
+ if x[self.id_origin_call] >= data.origin_call_train_size:
+ x[self.id_origin_call] = numpy.int32(0)
+ return tuple(x)