aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAdeB <adbrebs@gmail.com>2015-05-05 22:15:22 -0400
committerAdeB <adbrebs@gmail.com>2015-05-05 22:15:22 -0400
commitf4d3ee6449217535bdbe19ac9c5fdd825d71b0d3 (patch)
treeb2dfd7f6f914f5f9e4521634b9ffc4a2b0171fdd /train.py
parent54613c1f9cf510ca7a71d6619418f2247515aec6 (diff)
downloadtaxi-f4d3ee6449217535bdbe19ac9c5fdd825d71b0d3.tar.gz
taxi-f4d3ee6449217535bdbe19ac9c5fdd825d71b0d3.zip
New hyperparameters. Training error is monitored.
Diffstat (limited to 'train.py')
-rw-r--r--train.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/train.py b/train.py
index 4cbd526..f501cba 100644
--- a/train.py
+++ b/train.py
@@ -27,7 +27,7 @@ from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.saveload import Dump, LoadFromDump, Checkpoint
-from blocks.extensions.monitoring import DataStreamMonitoring
+from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
import data
import transformers
@@ -107,7 +107,8 @@ def main():
step_rule=Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
params=params)
- extensions=[DataStreamMonitoring(model.monitor, valid_stream,
+ extensions=[TrainingDataMonitoring(model.monitor, prefix='train', every_n_batches=1000),
+ DataStreamMonitoring(model.monitor, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),