about summary refs log tree commit diff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-x train.py 33
1 files changed, 29 insertions, 4 deletions
diff --git a/train.py b/train.py
index 65f9050..876fcba 100755
--- a/train.py
+++ b/train.py
@@ -14,7 +14,8 @@ from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNo
from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
from blocks.extensions.plot import Plot
-from blocks.extensions.saveload import Dump, LoadFromDump
+from blocks.extensions.saveload import LoadFromDump, Dump
+from blocks.dump import MainLoopDumpManager
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
@@ -53,6 +54,25 @@ class ElementwiseRemoveNotFinite(StepRule):
return step, []
+class CustomDumpManager(MainLoopDumpManager):
+    """Dump manager that saves and restores only the parameters and the log.
+
+    Uses the ``folder`` attribute and the ``dump_parameters``/``dump_log``/
+    ``load_parameters``/``load_log`` helpers of ``MainLoopDumpManager``.
+    """
+
+    def dump(self, main_loop):
+        """Dumps the parameters and the log of `main_loop` to the root folder.
+
+        See :mod:`blocks.dump`.
+        Overwrites the old data if present.
+        """
+        # makedirs instead of mkdir: the dump folder can be a nested path
+        # (e.g. 'model_data/<model_name>') whose parent does not exist yet,
+        # in which case os.mkdir raises OSError.
+        if not os.path.exists(self.folder):
+            os.makedirs(self.folder)
+        self.dump_parameters(main_loop)
+        self.dump_log(main_loop)
+
+    def load(self):
+        """Returns the dumped ``(parameters, log)`` pair."""
+        return (self.load_parameters(), self.load_log())
+
+    def load_to(self, main_loop):
+        """Loads the dump from the root folder into the main loop.
+
+        Sets the model parameter values and replaces ``main_loop.log``.
+        """
+        parameters, log = self.load()
+        main_loop.model.set_param_values(parameters)
+        main_loop.log = log
if __name__ == "__main__":
if len(sys.argv) != 2:
@@ -109,15 +129,20 @@ if __name__ == "__main__":
dump_path = os.path.join('model_data', model_name)
logger.info('Dump path: %s' % dump_path)
+
+ load_dump_ext = LoadFromDump(dump_path)
+ dump_ext = Dump(dump_path, every_n_batches=1000)
+ load_dump_ext.manager = CustomDumpManager(dump_path)
+ dump_ext.manager = CustomDumpManager(dump_path)
+
extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
DataStreamMonitoring(monitored, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
Plot(model_name, channels=plot_vars, every_n_batches=500),
- Dump(dump_path, every_n_batches=5000),
- LoadFromDump(dump_path),
- #FinishAfter(after_n_batches=2),
+ load_dump_ext,
+ dump_ext
]
main_loop = MainLoop(