diff options
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 18 |
1 files changed, 11 insertions, 7 deletions
@@ -9,12 +9,16 @@ from functools import reduce from theano import tensor +import blocks +import fuel + from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -import blocks + blocks.config.default_seed = 123 +fuel.config.default_seed = 123 try: from blocks.extras.extensions.plotting import Plot @@ -104,12 +108,12 @@ if __name__ == "__main__": every_n_batches=1000), Printing(every_n_batches=1000), - SaveLoadParams(dump_path, cg, - before_training=True, # before training -> load params - every_n_batches=1000, # every N batches -> save params - after_epoch=True, # after epoch -> save params - after_training=True, # after training -> save params - ), + # SaveLoadParams(dump_path, cg, + # before_training=True, # before training -> load params + # every_n_batches=1000, # every N batches -> save params + # after_epoch=True, # after epoch -> save params + # after_training=True, # after training -> save params + # ), RunOnTest(model_name, model, |