aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-28 09:48:31 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-28 09:48:31 -0400
commit7c15286b6dadd1adc1f7406faed402a4bfe770f3 (patch)
treead40a2ed76a64eb5402d93ac97dd0a493aa7b40f
parent1c8241ab7a5e933c8a3452b407a1be054467613b (diff)
downloadtaxi-7c15286b6dadd1adc1f7406faed402a4bfe770f3.tar.gz
taxi-7c15286b6dadd1adc1f7406faed402a4bfe770f3.zip
Memory net changes
-rw-r--r--config/memory_network_mlp_3_momentum.py4
-rwxr-xr-xtrain.py5
2 files changed, 5 insertions, 4 deletions
diff --git a/config/memory_network_mlp_3_momentum.py b/config/memory_network_mlp_3_momentum.py
index 241142e..83fbc96 100644
--- a/config/memory_network_mlp_3_momentum.py
+++ b/config/memory_network_mlp_3_momentum.py
@@ -41,9 +41,9 @@ candidate_encoder.dim_embeddings = dim_embeddings
representation_size = 500
representation_activation = Tanh
-normalize_representation = True
+normalize_representation = False
-step_rule = Momentum(learning_rate=0.01, momentum=0.9)
+step_rule = Momentum(learning_rate=0.001, momentum=0.9)
batch_size = 5000
# batch_sort_size = 20
diff --git a/train.py b/train.py
index 1d59bb5..27d9b77 100755
--- a/train.py
+++ b/train.py
@@ -114,14 +114,15 @@ if __name__ == "__main__":
extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=monitor_freq),
DataStreamMonitoring(valid_monitored, valid_stream,
prefix='valid',
- every_n_batches=monitor_freq),
+ every_n_batches=monitor_freq,
+ after_epoch=False),
Printing(every_n_batches=monitor_freq),
FinishAfter(every_n_batches=10000000),
SaveLoadParams(dump_path, cg,
before_training=True, # before training -> load params
every_n_batches=monitor_freq,# every N batches -> save params
- after_epoch=True, # after epoch -> save params
+ after_epoch=False,
after_training=True, # after training -> save params
),