diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 14:38:06 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-25 14:38:06 -0400 |
commit | a67f85dd7a3d6ca69d9adf7cbac2cc796079d223 (patch) | |
tree | 0a8a9823f77e0889c6a445b6894b7d0036f0c500 | |
parent | a557b939eb104ec7e0df42193e014e3137eb70f8 (diff) | |
download | taxi-a67f85dd7a3d6ca69d9adf7cbac2cc796079d223.tar.gz taxi-a67f85dd7a3d6ca69d9adf7cbac2cc796079d223.zip |
Add monitor_freq config variable
-rwxr-xr-x | train.py | 15 |
1 files changed, 10 insertions, 5 deletions
@@ -106,16 +106,21 @@ if __name__ == "__main__": dump_path = os.path.join('model_data', model_name) + '.pkl' logger.info('Dump path: %s' % dump_path) - extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=10000), + if hasattr(config, 'monitor_freq'): + monitor_freq = config.monitor_freq + else: + monitor_freq = 10000 + + extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=monitor_freq), DataStreamMonitoring(valid_monitored, valid_stream, prefix='valid', - every_n_batches=10000), - Printing(every_n_batches=10000), + every_n_batches=monitor_freq), + Printing(every_n_batches=monitor_freq), FinishAfter(every_n_batches=10000000), SaveLoadParams(dump_path, cg, before_training=True, # before training -> load params - every_n_batches=10000, # every N batches -> save params + every_n_batches=monitor_freq,# every N batches -> save params after_epoch=True, # after epoch -> save params after_training=True, # after training -> save params ), @@ -123,7 +128,7 @@ if __name__ == "__main__": RunOnTest(model_name, model, stream, - every_n_batches=10000), + every_n_batches=monitor_freq), ] if '--progress' in sys.argv: |