aboutsummaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 11:15:37 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-02 11:18:45 -0400
commit3f3ab2bfe3ebfa266d433012be1c89c722d63352 (patch)
tree589915018911ec364dccb4b897ab108913be464f /train.py
parent32b078f28add3d22529e55aeac6674d924e9b510 (diff)
downloadtaxi-3f3ab2bfe3ebfa266d433012be1c89c722d63352.tar.gz
taxi-3f3ab2bfe3ebfa266d433012be1c89c722d63352.zip
Unify parameters for joint_simple_tgtcls_111_cswdtx_bigger{,_dropout}
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/train.py b/train.py
index 17d5789..0d40f84 100755
--- a/train.py
+++ b/train.py
@@ -70,6 +70,7 @@ class SaveLoadParams(SimpleExtension):
with open(self.path, 'w') as f:
logger.info('Saving parameters to %s...'%self.path)
cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ logger.info('Done saving.')
def do_load(self):
try:
@@ -153,8 +154,10 @@ if __name__ == "__main__":
Printing(every_n_batches=1000),
SaveLoadParams(dump_path, cg,
- before_training=config.load_model, # before training -> load params
+ before_training=True, # before training -> load params
every_n_batches=1000, # every N batches -> save params
+ after_epoch=True, # after epoch -> save params
+ after_training=True, # after training -> save params
),
]