aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py20
1 files changed, 12 insertions, 8 deletions
diff --git a/train.py b/train.py
index e3244ec..8260218 100755
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@ blocks.config.default_seed = 123
fuel.config.default_seed = 123
try:
- from blocks.extras.extensions.plotting import Plot
+ from blocks.extras.extensions.plot import Plot
use_plot = True
except ImportError:
use_plot = False
@@ -77,10 +77,10 @@ if __name__ == "__main__":
logger.info('# Parameter shapes:')
parameters_size = 0
- for key, value in cg.get_params().iteritems():
- logger.info(' %20s %s' % (value.get_value().shape, key))
+ for value in cg.parameters:
+ logger.info(' %20s %s' % (value.get_value().shape, value.name))
parameters_size += reduce(operator.mul, value.get_value().shape, 1)
- logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params())))
+ logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters)))
if hasattr(config, 'step_rule'):
step_rule = config.step_rule
@@ -97,9 +97,10 @@ if __name__ == "__main__":
RemoveNotFinite(),
step_rule
]),
- params=params)
+ parameters=params)
- plot_vars = [['valid_' + x.name for x in valid_monitored]]
+ plot_vars = [['valid_' + x.name for x in valid_monitored] +
+ ['train_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
dump_path = os.path.join('model_data', model_name) + '.pkl'
@@ -110,7 +111,7 @@ if __name__ == "__main__":
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
- # FinishAfter(every_n_batches=10),
+ FinishAfter(every_n_batches=10000000),
SaveLoadParams(dump_path, cg,
before_training=True, # before training -> load params
@@ -126,7 +127,10 @@ if __name__ == "__main__":
]
if use_plot:
- extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500))
+ extensions.append(Plot(model_name,
+ channels=plot_vars,
+ every_n_batches=500,
+ server_url='http://eos6:5006/'))
main_loop = MainLoop(
model=cg,