aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 12:49:03 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 12:49:18 -0400
commita4b190516d00428b1d8a81686a3291e5fa5f9865 (patch)
tree230f04cbb664d4f7138ca4f22839e6bf501b32be /train.py
parent859bee7196c78e9828d9182b5fea2ad2eab01f1d (diff)
downloadtaxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.tar.gz
taxi-a4b190516d00428b1d8a81686a3291e5fa5f9865.zip
Make the testing into an extension run at each validation
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py69
1 files changed, 9 insertions, 60 deletions
diff --git a/train.py b/train.py
index d40cb88..6d3f37b 100755
--- a/train.py
+++ b/train.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python2
-import cPickle
import importlib
import logging
import operator
@@ -12,7 +11,7 @@ from theano import tensor
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
-from blocks.extensions import Printing, FinishAfter, SimpleExtension
+from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
import blocks
blocks.config.default_seed = 123
@@ -28,66 +27,11 @@ from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
from blocks.model import Model
+from ext_saveload import SaveLoadParams
+from ext_test import RunOnTest
logger = logging.getLogger(__name__)
-
-class ElementwiseRemoveNotFinite(StepRule):
- """A step rule that replaces non-finite coefficients by zeros.
-
- Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step
- (the parameter update of a single shared variable)
- with a scaled version of the parameters being updated instead.
-
- Parameters
- ----------
- scaler : float, optional
- The scaling applied to the parameter in case the step contains
- non-finite elements. Defaults to 0.1.
-
- Notes
- -----
- This trick was originally used in the GroundHog_ framework.
-
- .. _GroundHog: https://github.com/lisa-groundhog/GroundHog
-
- """
- def __init__(self, scaler=0.1):
- self.scaler = scaler
-
- def compute_step(self, param, previous_step):
- not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
- step = tensor.switch(not_finite, self.scaler * param, previous_step)
-
- return step, []
-
-class SaveLoadParams(SimpleExtension):
- def __init__(self, path, model, **kwargs):
- super(SaveLoadParams, self).__init__(**kwargs)
-
- self.path = path
- self.model = model
-
- def do_save(self):
- with open(self.path, 'w') as f:
- logger.info('Saving parameters to %s...'%self.path)
- cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
- logger.info('Done saving.')
-
- def do_load(self):
- try:
- with open(self.path, 'r') as f:
- logger.info('Loading parameters from %s...'%self.path)
- self.model.set_param_values(cPickle.load(f))
- except IOError:
- pass
-
- def do(self, which_callback, *args):
- if which_callback == 'before_training':
- self.do_load()
- else:
- self.do_save()
-
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
@@ -143,7 +87,7 @@ if __name__ == "__main__":
algorithm = GradientDescent(
cost=cost,
step_rule=CompositeRule([
- ElementwiseRemoveNotFinite(),
+ RemoveNotFinite(),
step_rule
]),
params=params)
@@ -166,6 +110,11 @@ if __name__ == "__main__":
after_epoch=True, # after epoch -> save params
after_training=True, # after training -> save params
),
+
+ RunOnTest(model_name,
+ model,
+ stream,
+ every_n_batches=1000),
]
if use_plot: