From bd08e452093bba68fe2d79b1e9da76488b203720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Mon, 22 Jun 2015 14:40:19 -0400 Subject: Update memory network --- config/memory_network_1.py | 5 +++-- data/cut.py | 5 +++-- model/memory_network.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/config/memory_network_1.py b/config/memory_network_1.py index 36d16ed..68e23bd 100644 --- a/config/memory_network_1.py +++ b/config/memory_network_1.py @@ -21,13 +21,13 @@ class MLPConfig(object): prefix_encoder = MLPConfig() prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -prefix_encoder.dim_hidden = [50] +prefix_encoder.dim_hidden = [100, 100, 100] prefix_encoder.weights_init = IsotropicGaussian(0.01) prefix_encoder.biases_init = Constant(0.001) candidate_encoder = MLPConfig() candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings) -candidate_encoder.dim_hidden = [50] +candidate_encoder.dim_hidden = [100, 100, 100] candidate_encoder.weights_init = IsotropicGaussian(0.01) candidate_encoder.biases_init = Constant(0.001) @@ -38,6 +38,7 @@ batch_size = 32 valid_set = 'cuts/test_times_0' max_splits = 1 +num_cuts = 1000 train_candidate_size = 1000 valid_candidate_size = 10000 diff --git a/data/cut.py b/data/cut.py index fc0b3f9..6e1e4e5 100644 --- a/data/cut.py +++ b/data/cut.py @@ -11,14 +11,15 @@ last_time = 1404172787 class TaxiTimeCutScheme(IterationScheme): - def __init__(self, dbfile=None, use_cuts=None): + def __init__(self, num_cuts=100, dbfile=None, use_cuts=None): + self.num_cuts = num_cuts self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile self.use_cuts = use_cuts def get_request_iterator(self): cuts = self.use_cuts if cuts == None: - cuts = [random.randrange(first_time, last_time) for _ in range(100)] + cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)] l = [] with sqlite3.connect(self.dbfile) as db: diff --git a/model/memory_network.py b/model/memory_network.py index 92e83e2..5acbfe3 100644 --- a/model/memory_network.py +++ b/model/memory_network.py @@ -88,7 +88,7 @@ class Stream(object): dataset = TaxiDataset('train') - prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme()) + prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts)) prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids) prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits) prefix_stream = transformers.taxi_add_datetime(prefix_stream) -- cgit v1.2.3