aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config/memory_network_1.py5
-rw-r--r--data/cut.py5
-rw-r--r--model/memory_network.py2
3 files changed, 7 insertions, 5 deletions
diff --git a/config/memory_network_1.py b/config/memory_network_1.py
index 36d16ed..68e23bd 100644
--- a/config/memory_network_1.py
+++ b/config/memory_network_1.py
@@ -21,13 +21,13 @@ class MLPConfig(object):
prefix_encoder = MLPConfig()
prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
-prefix_encoder.dim_hidden = [50]
+prefix_encoder.dim_hidden = [100, 100, 100]
prefix_encoder.weights_init = IsotropicGaussian(0.01)
prefix_encoder.biases_init = Constant(0.001)
candidate_encoder = MLPConfig()
candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
-candidate_encoder.dim_hidden = [50]
+candidate_encoder.dim_hidden = [100, 100, 100]
candidate_encoder.weights_init = IsotropicGaussian(0.01)
candidate_encoder.biases_init = Constant(0.001)
@@ -38,6 +38,7 @@ batch_size = 32
valid_set = 'cuts/test_times_0'
max_splits = 1
+num_cuts = 1000
train_candidate_size = 1000
valid_candidate_size = 10000
diff --git a/data/cut.py b/data/cut.py
index fc0b3f9..6e1e4e5 100644
--- a/data/cut.py
+++ b/data/cut.py
@@ -11,14 +11,15 @@ last_time = 1404172787
class TaxiTimeCutScheme(IterationScheme):
- def __init__(self, dbfile=None, use_cuts=None):
+ def __init__(self, num_cuts=100, dbfile=None, use_cuts=None):
+ self.num_cuts = num_cuts
self.dbfile = os.path.join(data.path, 'time_index.db') if dbfile == None else dbfile
self.use_cuts = use_cuts
def get_request_iterator(self):
cuts = self.use_cuts
if cuts == None:
- cuts = [random.randrange(first_time, last_time) for _ in range(100)]
+ cuts = [random.randrange(first_time, last_time) for _ in range(self.num_cuts)]
l = []
with sqlite3.connect(self.dbfile) as db:
diff --git a/model/memory_network.py b/model/memory_network.py
index 92e83e2..5acbfe3 100644
--- a/model/memory_network.py
+++ b/model/memory_network.py
@@ -88,7 +88,7 @@ class Stream(object):
dataset = TaxiDataset('train')
- prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme())
+ prefix_stream = DataStream(dataset, iteration_scheme=TaxiTimeCutScheme(self.config.num_cuts))
prefix_stream = transformers.TaxiExcludeTrips(prefix_stream, valid_trips_ids)
prefix_stream = transformers.TaxiGenerateSplits(prefix_stream, max_splits=self.config.max_splits)
prefix_stream = transformers.taxi_add_datetime(prefix_stream)