aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ext_saveload.py4
-rwxr-xr-xtrain.py20
2 files changed, 14 insertions, 10 deletions
diff --git a/ext_saveload.py b/ext_saveload.py
index cc7c47a..059c5cf 100644
--- a/ext_saveload.py
+++ b/ext_saveload.py
@@ -15,14 +15,14 @@ class SaveLoadParams(SimpleExtension):
def do_save(self):
with open(self.path, 'w') as f:
logger.info('Saving parameters to %s...'%self.path)
- cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ cPickle.dump(self.model.get_parameter_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
logger.info('Done saving.')
def do_load(self):
try:
with open(self.path, 'r') as f:
logger.info('Loading parameters from %s...'%self.path)
- self.model.set_param_values(cPickle.load(f))
+ self.model.set_parameter_values(cPickle.load(f))
except IOError:
pass
diff --git a/train.py b/train.py
index e3244ec..8260218 100755
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@ blocks.config.default_seed = 123
fuel.config.default_seed = 123
try:
- from blocks.extras.extensions.plotting import Plot
+ from blocks.extras.extensions.plot import Plot
use_plot = True
except ImportError:
use_plot = False
@@ -77,10 +77,10 @@ if __name__ == "__main__":
logger.info('# Parameter shapes:')
parameters_size = 0
- for key, value in cg.get_params().iteritems():
- logger.info(' %20s %s' % (value.get_value().shape, key))
+ for value in cg.parameters:
+ logger.info(' %20s %s' % (value.get_value().shape, value.name))
parameters_size += reduce(operator.mul, value.get_value().shape, 1)
- logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params())))
+ logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.parameters)))
if hasattr(config, 'step_rule'):
step_rule = config.step_rule
@@ -97,9 +97,10 @@ if __name__ == "__main__":
RemoveNotFinite(),
step_rule
]),
- params=params)
+ parameters=params)
- plot_vars = [['valid_' + x.name for x in valid_monitored]]
+ plot_vars = [['valid_' + x.name for x in valid_monitored] +
+ ['train_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
dump_path = os.path.join('model_data', model_name) + '.pkl'
@@ -110,7 +111,7 @@ if __name__ == "__main__":
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
- # FinishAfter(every_n_batches=10),
+ FinishAfter(every_n_batches=10000000),
SaveLoadParams(dump_path, cg,
before_training=True, # before training -> load params
@@ -126,7 +127,10 @@ if __name__ == "__main__":
]
if use_plot:
- extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500))
+ extensions.append(Plot(model_name,
+ channels=plot_vars,
+ every_n_batches=500,
+ server_url='http://eos6:5006/'))
main_loop = MainLoop(
model=cg,