aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-21 10:46:05 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-21 10:46:05 -0400
commit9a779f7328a712a20dd393bdf32c6a84bf9fbe52 (patch)
tree1906e27febd8b0b76ff1de4bdec74e56cbac15e4 /train.py
parent9ff3d163609707c0138c0de731eec40449bd1815 (diff)
downloadtaxi-9a779f7328a712a20dd393bdf32c6a84bf9fbe52.tar.gz
taxi-9a779f7328a712a20dd393bdf32c6a84bf9fbe52.zip
Model changes
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train.py b/train.py
index 96dd798..677ed45 100755
--- a/train.py
+++ b/train.py
@@ -73,7 +73,7 @@ def setup_test_stream(req_vars):
test = transformers.TaxiAddFirstLastLen(config.n_begin_end_pts, test)
test = transformers.Select(test, tuple(req_vars))
- test_stream = Batch(test, iteration_scheme=ConstantScheme(1000))
+ test_stream = Batch(test, iteration_scheme=ConstantScheme(1))
return test_stream
@@ -100,8 +100,8 @@ def main():
cost=cost,
step_rule=CompositeRule([
RemoveNotFinite(),
- #AdaDelta(decay_rate=0.95),
- Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
+ AdaDelta(),
+ #Momentum(learning_rate=config.learning_rate, momentum=config.momentum),
]),
params=params)