aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 13:23:28 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 13:25:33 -0400
commit5096e0cdae167122d07b09cd207a04f28ea5c3f5 (patch)
treeba15ca59dce8b301330b8ef2f282099e5f6991a2 /train.py
parent98139f573eb179c8f5a06ba6c8d8883376814ccf (diff)
downloadtaxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.tar.gz
taxi-5096e0cdae167122d07b09cd207a04f28ea5c3f5.zip
Add random seed for TaxiGenerateSplits and for fuel
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/train.py b/train.py
index 6d3f37b..77dca53 100755
--- a/train.py
+++ b/train.py
@@ -9,12 +9,16 @@ from functools import reduce
from theano import tensor
+import blocks
+import fuel
+
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-import blocks
+
blocks.config.default_seed = 123
+fuel.config.default_seed = 123
try:
from blocks.extras.extensions.plotting import Plot
@@ -104,12 +108,12 @@ if __name__ == "__main__":
every_n_batches=1000),
Printing(every_n_batches=1000),
- SaveLoadParams(dump_path, cg,
- before_training=True, # before training -> load params
- every_n_batches=1000, # every N batches -> save params
- after_epoch=True, # after epoch -> save params
- after_training=True, # after training -> save params
- ),
+ # SaveLoadParams(dump_path, cg,
+ # before_training=True, # before training -> load params
+ # every_n_batches=1000, # every N batches -> save params
+ # after_epoch=True, # after epoch -> save params
+ # after_training=True, # after training -> save params
+ # ),
RunOnTest(model_name,
model,