From bd08e452093bba68fe2d79b1e9da76488b203720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Mon, 22 Jun 2015 14:40:19 -0400 Subject: Update memory network --- data/cut.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'data/cut.py') diff --git a/data/cut.py b/data/cut.py index fc0b3f9..6e1e4e5 100644 --- a/data/cut.py +++ b/data/cut.py @@ -11,14 +11,15 @@ last_time = 1404172787 class TaxiTimeCutScheme(IterationScheme): - def __init__(self, dbfile=None, use_cuts=None): + def __init__(self, num_cuts=100, dbfile=None, use_cuts=None): + self.num_cuts = num_cuts self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile self.use_cuts = use_cuts def get_request_iterator(self): cuts = self.use_cuts if cuts == None: - cuts = [random.randrange(first_time, last_time) for _ in range(100)] + cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)] l = [] with sqlite3.connect(self.dbfile) as db: -- cgit v1.2.3