From 60e6bc64d8e3c6679a6e2a960513c656d481f0ed Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 24 Jul 2015 11:59:10 -0400 Subject: Add --progres option for ProgressBar --- train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index fb50a2f..b4ef774 100755 --- a/train.py +++ b/train.py @@ -14,7 +14,7 @@ import fuel from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum -from blocks.extensions import Printing, FinishAfter +from blocks.extensions import Printing, FinishAfter, ProgressBar from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring blocks.config.default_seed = 123 @@ -37,8 +37,8 @@ from ext_test import RunOnTest logger = logging.getLogger(__name__) if __name__ == "__main__": - if len(sys.argv) < 2 or len(sys.argv) > 3: - print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] config' % sys.argv[0] + if len(sys.argv) < 2 or len(sys.argv) > 4: + print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] [--progress] config' % sys.argv[0] sys.exit(1) model_name = sys.argv[-1] config = importlib.import_module('.%s' % model_name, 'config') @@ -125,6 +125,9 @@ if __name__ == "__main__": stream, every_n_batches=10000), ] + + if '--progress' in sys.argv: + extensions.append(ProgressBar()) if use_plot: extensions.append(Plot(model_name, -- cgit v1.2.3