aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:59:10 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-24 11:59:10 -0400
commit60e6bc64d8e3c6679a6e2a960513c656d481f0ed (patch)
treee3ba1d5fad7e59c69155069d2b33639c9c202da7 /train.py
parentff49937eef024916ac4560ce0134d94006e9e2e5 (diff)
downloadtaxi-60e6bc64d8e3c6679a6e2a960513c656d481f0ed.tar.gz
taxi-60e6bc64d8e3c6679a6e2a960513c656d481f0ed.zip
Add --progres option for ProgressBar
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/train.py b/train.py
index fb50a2f..b4ef774 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@ import fuel
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
-from blocks.extensions import Printing, FinishAfter
+from blocks.extensions import Printing, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
blocks.config.default_seed = 123
@@ -37,8 +37,8 @@ from ext_test import RunOnTest
logger = logging.getLogger(__name__)
if __name__ == "__main__":
- if len(sys.argv) < 2 or len(sys.argv) > 3:
- print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] config' % sys.argv[0]
+ if len(sys.argv) < 2 or len(sys.argv) > 4:
+ print >> sys.stderr, 'Usage: %s [--tvt | --largevalid] [--progress] config' % sys.argv[0]
sys.exit(1)
model_name = sys.argv[-1]
config = importlib.import_module('.%s' % model_name, 'config')
@@ -125,6 +125,9 @@ if __name__ == "__main__":
stream,
every_n_batches=10000),
]
+
+ if '--progress' in sys.argv:
+ extensions.append(ProgressBar())
if use_plot:
extensions.append(Plot(model_name,