aboutsummaryrefslogtreecommitdiff
path: root/data/cut.py
diff options
context:
space:
mode:
author: Alex Auvolat <alex.auvolat@ens.fr> 2015-05-21 17:05:07 -0400
committer: Alex Auvolat <alex.auvolat@ens.fr> 2015-05-21 17:05:07 -0400
commit: 222971123f2741ee3689092fb4396dac83a13338 (patch)
tree: 59acaae81aaf6f57474be05fb93de1eb00061689 /data/cut.py
parent: f6d2c6fc47f93b158b70b5c0c9a45324041ca4d5 (diff)
download: taxi-222971123f2741ee3689092fb4396dac83a13338.tar.gz
taxi-222971123f2741ee3689092fb4396dac83a13338.zip
Implement cut-based iteration scheme (SLOW!!!)
Diffstat (limited to 'data/cut.py')
-rw-r--r-- data/cut.py | 31
1 files changed, 31 insertions, 0 deletions
diff --git a/data/cut.py b/data/cut.py
new file mode 100644
index 0000000..1253434
--- /dev/null
+++ b/data/cut.py
@@ -0,0 +1,31 @@
+from fuel.schemes import IterationScheme
+import sqlite3
+import random
+import os
+from picklable_itertools import iter_
+
+import data
+
+# Bounds of the half-open range [first_time, last_time) from which
+# TaxiTimeCutScheme draws its random cuts when none are supplied.
+# NOTE(review): presumably Unix timestamps in seconds; if so they correspond
+# to 2013-07-01 00:00:53 UTC and 2014-06-30 23:59:47 UTC -- confirm against
+# the dataset.
+first_time = 1372636853
+last_time = 1404172787
+
+
+class TaxiTimeCutScheme(IterationScheme):
+    """Iteration scheme that selects trips through time cuts.
+
+    For every cut value, the ``trip_times`` table of an SQLite database is
+    queried for the rows satisfying ``begin <= cut <= end``; the ``trip``
+    values of those rows are the requests produced by the scheme.
+
+    Parameters
+    ----------
+    dbfile : str, optional
+        Path of the SQLite database. Defaults to ``time_index.db`` inside
+        ``data.path``.
+    use_cuts : iterable of int, optional
+        Cuts to query. When ``None``, 100 cuts are drawn at random from
+        ``[first_time, last_time)`` every time an iterator is requested.
+
+    """
+    def __init__(self, dbfile=None, use_cuts=None):
+        if dbfile is None:
+            dbfile = os.path.join(data.path, 'time_index.db')
+        self.dbfile = dbfile
+        self.use_cuts = use_cuts
+
+    def get_request_iterator(self):
+        """Return an iterator over the ``trip`` values matched by the cuts.
+
+        The results are concatenated in cut order, so a row matching several
+        cuts is returned once per matching cut.
+        """
+        cuts = self.use_cuts
+        if cuts is None:
+            cuts = [random.randrange(first_time, last_time) for _ in range(100)]
+
+        trips = []
+        # ``with sqlite3.connect(...)`` only scopes a transaction and leaves
+        # the connection open, so the connection is closed explicitly here.
+        db = sqlite3.connect(self.dbfile)
+        try:
+            c = db.cursor()
+            for cut in cuts:
+                c.execute('SELECT trip FROM trip_times WHERE begin <= ? AND end >= ?', (cut, cut))
+                # extend() appends in place; rebuilding the list with ``+``
+                # for every cut would copy all previous results each time.
+                trips.extend(i for (i,) in c)
+        finally:
+            db.close()
+
+        return iter_(trips)
+