From c3ec750f4a04e3117d658e8275dd3d91d2b0cbe4 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Thu, 11 Jun 2015 15:26:21 -0400 Subject: Add TaxiRemoveTestOnlyClients ; custom dumpmanager enabling multiprocessing --- data/cut.py | 1 + data/transformers.py | 10 ++++++++++ 2 files changed, 11 insertions(+) (limited to 'data') diff --git a/data/cut.py b/data/cut.py index 7853030..fc0b3f9 100644 --- a/data/cut.py +++ b/data/cut.py @@ -28,6 +28,7 @@ class TaxiTimeCutScheme(IterationScheme): c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?', (cut - 40000, cut, cut))] l = l + part + random.shuffle(l) return iter_(l) diff --git a/data/transformers.py b/data/transformers.py index 1b82dae..4814e9b 100644 --- a/data/transformers.py +++ b/data/transformers.py @@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer): if not data[self.id_trip_id] in self.exclude: break return data +class TaxiRemoveTestOnlyClients(Transformer): + def __init__(self, stream): + super(TaxiRemoveTestOnlyClients, self).__init__(stream) + self.id_origin_call = stream.sources.index('origin_call') + def get_data(self, request=None): + if request is not None: raise ValueError + x = list(next(self.child_epoch_iterator)) + if x[self.id_origin_call] >= data.origin_call_train_size: + x[self.id_origin_call] = numpy.int32(0) + return tuple(x) -- cgit v1.2.3