diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 13:23:28 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 13:25:33 -0400 |
commit | 5096e0cdae167122d07b09cd207a04f28ea5c3f5 (patch) | |
tree | ba15ca59dce8b301330b8ef2f282099e5f6991a2 /train.py | |
parent | 98139f573eb179c8f5a06ba6c8d8883376814ccf (diff) | |
download | taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.tar.gz taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.zip |
Add random seed for TaxiGenerateSplits and for fuel
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 18 |
1 files changed, 11 insertions, 7 deletions
@@ -9,12 +9,16 @@ from functools import reduce from theano import tensor +import blocks +import fuel + from blocks import roles from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum from blocks.extensions import Printing, FinishAfter from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring -import blocks + blocks.config.default_seed = 123 +fuel.config.default_seed = 123 try: from blocks.extras.extensions.plotting import Plot @@ -104,12 +108,12 @@ if __name__ == "__main__": every_n_batches=1000), Printing(every_n_batches=1000), - SaveLoadParams(dump_path, cg, - before_training=True, # before training -> load params - every_n_batches=1000, # every N batches -> save params - after_epoch=True, # after epoch -> save params - after_training=True, # after training -> save params - ), + # SaveLoadParams(dump_path, cg, + # before_training=True, # before training -> load params + # every_n_batches=1000, # every N batches -> save params + # after_epoch=True, # after epoch -> save params + # after_training=True, # after training -> save params + # ), RunOnTest(model_name, model, |