diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 11:15:37 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-02 11:18:45 -0400 |
commit | 3f3ab2bfe3ebfa266d433012be1c89c722d63352 (patch) | |
tree | 589915018911ec364dccb4b897ab108913be464f /train.py | |
parent | 32b078f28add3d22529e55aeac6674d924e9b510 (diff) | |
download | taxi-3f3ab2bfe3ebfa266d433012be1c89c722d63352.tar.gz taxi-3f3ab2bfe3ebfa266d433012be1c89c722d63352.zip |
Unify parameters for joint_simple_tgtcls_111_cswdtx_bigger{,_dropout}
Diffstat (limited to 'train.py')
-rwxr-xr-x | train.py | 5 |
1 files changed, 4 insertions, 1 deletions
@@ -70,6 +70,7 @@ class SaveLoadParams(SimpleExtension): with open(self.path, 'w') as f: logger.info('Saving parameters to %s...'%self.path) cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL) + logger.info('Done saving.') def do_load(self): try: @@ -153,8 +154,10 @@ if __name__ == "__main__": Printing(every_n_batches=1000), SaveLoadParams(dump_path, cg, - before_training=config.load_model, # before training -> load params + before_training=True, # before training -> load params every_n_batches=1000, # every N batches -> save params + after_epoch=True, # after epoch -> save params + after_training=True, # after training -> save params ), ] |