aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data/hdf5.py2
-rwxr-xr-xtrain.py76
2 files changed, 44 insertions, 34 deletions
diff --git a/data/hdf5.py b/data/hdf5.py
index 6e2f9a4..f3d6da2 100644
--- a/data/hdf5.py
+++ b/data/hdf5.py
@@ -13,7 +13,7 @@ class TaxiDataset(H5PYDataset):
def __init__(self, which_set, filename='data.hdf5', **kwargs):
self.filename = filename
kwargs.setdefault('load_in_memory', True)
- super(TaxiDataset, self).__init__(self.data_path, which_set, **kwargs)
+ super(TaxiDataset, self).__init__(self.data_path, (which_set,), **kwargs)
@property
def data_path(self):
diff --git a/train.py b/train.py
index 94c00d2..0bc697b 100755
--- a/train.py
+++ b/train.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python2
import importlib
import logging
@@ -11,11 +11,15 @@ from theano import tensor
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule
-from blocks.extensions import Printing, FinishAfter
+from blocks.extensions import Printing, FinishAfter, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extensions.plot import Plot
-from blocks.extensions.saveload import LoadFromDump, Dump
-from blocks.dump import MainLoopDumpManager
+
+try:
+ from blocks.extensions.plot import Plot
+ use_plot = True
+except ImportError:
+ use_plot = False
+
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
@@ -54,25 +58,31 @@ class ElementwiseRemoveNotFinite(StepRule):
return step, []
-class CustomDumpManager(MainLoopDumpManager):
- def dump(self, main_loop):
- """Dumps the main loop to the root folder.
- See :mod:`blocks.dump`.
- Overwrites the old data if present.
- """
- if not os.path.exists(self.folder):
- os.mkdir(self.folder)
- self.dump_parameters(main_loop)
- self.dump_log(main_loop)
-
- def load(self):
- return (self.load_parameters(), self.load_log())
-
- def load_to(self, main_loop):
- """Loads the dump from the root folder into the main loop."""
- parameters, log = self.load()
- main_loop.model.set_param_values(parameters)
- main_loop.log = log
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ self.model.set_param_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()
if __name__ == "__main__":
if len(sys.argv) != 2:
@@ -133,23 +143,23 @@ if __name__ == "__main__":
plot_vars = [['valid_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
- dump_path = os.path.join('model_data', model_name)
+ dump_path = os.path.join('model_data', model_name) + '.pkl'
logger.info('Dump path: %s' % dump_path)
- load_dump_ext = LoadFromDump(dump_path)
- dump_ext = Dump(dump_path, every_n_batches=1000)
- load_dump_ext.manager = CustomDumpManager(dump_path)
- dump_ext.manager = CustomDumpManager(dump_path)
-
extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
DataStreamMonitoring(valid_monitored, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
- Plot(model_name, channels=plot_vars, every_n_batches=500),
- load_dump_ext,
- dump_ext
+
+ SaveLoadParams(dump_path, cg,
+ before_training=True, # before training -> load params
+ every_n_batches=1000, # every N batches -> save params
+ ),
]
+
+ if use_plot:
+ extensions.append(Plot(model_name, channels=plot_vars, every_n_batches=500))
main_loop = MainLoop(
model=cg,