diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:59:10 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-24 11:59:10 -0400 |
commit | 60e6bc64d8e3c6679a6e2a960513c656d481f0ed (patch) | |
tree | e3ba1d5fad7e59c69155069d2b33639c9c202da7 /train.py | |
parent | ff49937eef024916ac4560ce0134d94006e9e2e5 (diff) | |
download | taxi-60e6bc64d8e3c6679a6e2a960513c656d481f0ed.tar.gz taxi-60e6bc64d8e3c6679a6e2a960513c656d481f0ed.zip |
Add --progres option for ProgressBar
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 9 |
1 files changed, 6 insertions, 3 deletions
@@ -14,7 +14,7 @@ import fuel from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum -from blocks.extensions import Printing, FinishAfter +from blocks.extensions import Printing, FinishAfter, ProgressBar from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring blocks.config.default_seed = 123 @@ -37,8 +37,8 @@ from ext_test import RunOnTest logger = logging.getLogger(__name__) if __name__ == "__main__": - if len(sys.argv) < 2 or len(sys.argv) > 3: - print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] config' % sys.argv[0] + if len(sys.argv) < 2 or len(sys.argv) > 4: + print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] [--progress] config' % sys.argv[0] sys.exit(1) model_name = sys.argv[-1] config = importlib.import_module('.%s' % model_name, 'config') @@ -125,6 +125,9 @@ if __name__ == "__main__": stream, every_n_batches=10000), ] + + if '--progress' in sys.argv: + extensions.append(ProgressBar()) if use_plot: extensions.append(Plot(model_name, |