aboutsummaryrefslogtreecommitdiff
path: root/data
diff options
context:
space:
mode:
Diffstat (limited to 'data')
-rw-r--r--data/cut.py1
-rw-r--r--data/transformers.py10
2 files changed, 11 insertions, 0 deletions
diff --git a/data/cut.py b/data/cut.py
index 7853030..fc0b3f9 100644
--- a/data/cut.py
+++ b/data/cut.py
@@ -28,6 +28,7 @@ class TaxiTimeCutScheme(IterationScheme):
c.execute('SELECT trip FROM trip_times WHERE begin >= ? AND begin <= ? AND end >= ?',
(cut - 40000, cut, cut))]
l = l + part
+ random.shuffle(l)
return iter_(l)
diff --git a/data/transformers.py b/data/transformers.py
index 1b82dae..4814e9b 100644
--- a/data/transformers.py
+++ b/data/transformers.py
@@ -119,3 +119,13 @@ class TaxiExcludeTrips(Transformer):
if not data[self.id_trip_id] in self.exclude: break
return data
+class TaxiRemoveTestOnlyClients(Transformer):
+ def __init__(self, stream):
+ super(TaxiRemoveTestOnlyClients, self).__init__(stream)
+ self.id_origin_call = stream.sources.index('origin_call')
+ def get_data(self, request=None):
+ if request is not None: raise ValueError
+ x = list(next(self.child_epoch_iterator))
+ if x[self.id_origin_call] >= data.origin_call_train_size:
+ x[self.id_origin_call] = numpy.int32(0)
+ return tuple(x)