aboutsummaryrefslogtreecommitdiff
path: root/config/memory_network_1.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-06-22 14:40:19 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-06-22 14:40:19 -0400
commitbd08e452093bba68fe2d79b1e9da76488b203720 (patch)
tree48d82d8cf883ad642c483e2cad5cc707dd8c4694 /config/memory_network_1.py
parentad5a03c6f60e5b2d543326bf8917b48e5b390b82 (diff)
downloadtaxi-bd08e452093bba68fe2d79b1e9da76488b203720.tar.gz
taxi-bd08e452093bba68fe2d79b1e9da76488b203720.zip
Update memory network
Diffstat (limited to 'config/memory_network_1.py')
-rw-r--r--config/memory_network_1.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/config/memory_network_1.py b/config/memory_network_1.py
index 36d16ed..68e23bd 100644
--- a/config/memory_network_1.py
+++ b/config/memory_network_1.py
@@ -21,13 +21,13 @@ class MLPConfig(object):
prefix_encoder = MLPConfig()
prefix_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
-prefix_encoder.dim_hidden = [50]
+prefix_encoder.dim_hidden = [100, 100, 100]
prefix_encoder.weights_init = IsotropicGaussian(0.01)
prefix_encoder.biases_init = Constant(0.001)
candidate_encoder = MLPConfig()
candidate_encoder.dim_input = n_begin_end_pts * 2 * 2 + sum(x for (_, _, x) in dim_embeddings)
-candidate_encoder.dim_hidden = [50]
+candidate_encoder.dim_hidden = [100, 100, 100]
candidate_encoder.weights_init = IsotropicGaussian(0.01)
candidate_encoder.biases_init = Constant(0.001)
@@ -38,6 +38,7 @@ batch_size = 32
valid_set = 'cuts/test_times_0'
max_splits = 1
+num_cuts = 1000
train_candidate_size = 1000
valid_candidate_size = 10000