diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 11:30:41 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 11:30:41 -0400 |
commit | b637e0bc7b123fe41ea2247ebb7aa311c88b81e0 (patch) | |
tree | 0ee61c124d67a3f6e928ecea2235082df8bd323b /train.py | |
parent | 3f3ab2bfe3ebfa266d433012be1c89c722d63352 (diff) | |
download | taxi-b637e0bc7b123fe41ea2247ebb7aa311c88b81e0.tar.gz taxi-b637e0bc7b123fe41ea2247ebb7aa311c88b81e0.zip |
Step rule & dropout params cleanup
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 7 |
1 files changed, 6 insertions, 1 deletions
@@ -132,12 +132,17 @@ if __name__ == "__main__": parameters_size += reduce(operator.mul, value.get_value().shape, 1) logger.info('Total number of parameters: %d in %d matrices' % (parameters_size, len(cg.get_params()))) + if hasattr(config, 'step_rule'): + step_rule = config.step_rule + else: + step_rule = AdaDelta() + params = cg.parameters algorithm = GradientDescent( cost=cost, step_rule=CompositeRule([ ElementwiseRemoveNotFinite(), - config.step_rule, + step_rule ]), params=params) |