diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-11 15:26:21 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-06-11 15:26:21 -0400 |
commit | c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4 (patch) | |
tree | 6fb3fcc9425948a13fdc421534b7a00574c1e595 /data/transformers.py | |
parent | 7e8dbac77ce712846954bdd5f4bfb62b6efaf7df (diff) | |
download | taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.tar.gz taxi-c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4.zip |
Add TaxiRemoveTestOnlyClients ; custom dumpmanager enabling multiprocessing
Diffstat (limited to 'data/transformers.py')
-rw-r--r-- | data/transformers.py | 10 |
1 files changed, 10 insertions, 0 deletions
```diff
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
             if not data[self.id_trip_id] in self.exclude: break
         return data
 
+class TaxiRemoveTestOnlyClients(Transformer):
+    def __init__(self, stream):
+        super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+        self.id_origin_call = stream.sources.index('origin_call')
+    def get_data(self, request=None):
+        if request is not None: raise ValueError
+        x = list(next(self.child_epoch_iterator))
+        if x[self.id_origin_call] >= data.origin_call_train_size:
+            x[self.id_origin_call] = numpy.int32(0)
+        return tuple(x)
```