aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ext_saveload.py33
-rw-r--r--ext_test.py66
-rwxr-xr-xtest.py54
-rwxr-xr-xtrain.py69
4 files changed, 108 insertions, 114 deletions
diff --git a/ext_saveload.py b/ext_saveload.py
new file mode 100644
index 0000000..cc7c47a
--- /dev/null
+++ b/ext_saveload.py
@@ -0,0 +1,33 @@
+import cPickle
+import logging
+
+from blocks.extensions import SimpleExtension
+
+logger = logging.getLogger(__name__)
+
+class SaveLoadParams(SimpleExtension):
+ def __init__(self, path, model, **kwargs):
+ super(SaveLoadParams, self).__init__(**kwargs)
+
+ self.path = path
+ self.model = model
+
+ def do_save(self):
+ with open(self.path, 'w') as f:
+ logger.info('Saving parameters to %s...'%self.path)
+ cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
+ logger.info('Done saving.')
+
+ def do_load(self):
+ try:
+ with open(self.path, 'r') as f:
+ logger.info('Loading parameters from %s...'%self.path)
+ self.model.set_param_values(cPickle.load(f))
+ except IOError:
+ pass
+
+ def do(self, which_callback, *args):
+ if which_callback == 'before_training':
+ self.do_load()
+ else:
+ self.do_save()
diff --git a/ext_test.py b/ext_test.py
new file mode 100644
index 0000000..6a3fa0a
--- /dev/null
+++ b/ext_test.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+
+import logging
+import os
+import csv
+
+from blocks.model import Model
+from blocks.extensions import SimpleExtension
+
+logger = logging.getLogger(__name__)
+
+class RunOnTest(SimpleExtension):
+ def __init__(self, model_name, model, stream, **kwargs):
+ super(RunOnTest, self).__init__(**kwargs)
+
+ self.model_name = model_name
+
+ cg = Model(model.predict(**stream.inputs()))
+
+ self.inputs = cg.inputs
+ self.outputs = model.predict.outputs
+
+ req_vars_test = model.predict.inputs + ['trip_id']
+ self.test_stream = stream.test(req_vars_test)
+
+ self.function = cg.get_theano_function()
+
+ def do(self, which_callback, *args):
+ iter_no = repr(self.main_loop.log.status['iterations_done'])
+ if 'valid_destination_cost' in self.main_loop.log.current_row:
+ dvc = self.main_loop.log.current_row['valid_destination_cost']
+ else:
+ dvc = self.main_loop.log.current_row['valid_model_cost_cost']
+ if 'valid_time_cost' in self.main_loop.log.current_row:
+ tvc = self.main_loop.log.current_row['valid_time_cost']
+ else:
+ tvc = self.main_loop.log.current_row['valid_model_cost_cost']
+
+ if 'destination' in self.outputs:
+ dest_outname = 'test-dest-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, dvc)
+ dest_outfile = open(os.path.join('output', dest_outname), 'w')
+ dest_outcsv = csv.writer(dest_outfile)
+ dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
+ logger.info("Generating output for test set: %s" % dest_outname)
+ if 'duration' in self.outputs:
+ time_outname = 'test-time-%s-it%s-cost%.3f.csv' % (self.model_name, iter_no, tvc)
+ time_outfile = open(os.path.join('output', time_outname), 'w')
+ time_outcsv = csv.writer(time_outfile)
+ time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
+ logger.info("Generating output for test set: %s" % time_outname)
+
+ for d in self.test_stream.get_epoch_iterator(as_dict=True):
+ input_values = [d[k.name] for k in self.inputs]
+ output_values = self.function(*input_values)
+ if 'destination' in self.outputs:
+ destination = output_values[self.outputs.index('destination')]
+ dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]])
+ if 'duration' in self.outputs:
+ duration = output_values[self.outputs.index('duration')]
+ time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))])
+
+ if 'destination' in self.outputs:
+ dest_outfile.close()
+ if 'duration' in self.outputs:
+ time_outfile.close()
+
diff --git a/test.py b/test.py
deleted file mode 100755
index 4925b27..0000000
--- a/test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-#!/usr/bin/env python
-
-import cPickle
-import sys
-import os
-import importlib
-import csv
-
-from blocks.model import Model
-
-
-if __name__ == "__main__":
- if len(sys.argv) != 2:
- print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
- sys.exit(1)
- model_name = sys.argv[1]
- config = importlib.import_module('.%s' % model_name, 'config')
- model_config = config.Model(config)
-
- stream = config.Stream(config)
- inputs = stream.inputs()
- outputs = model_config.predict.outputs
- req_vars_test = model_config.predict.inputs + ['trip_id']
- test_stream = stream.test(req_vars_test)
-
- model = Model(model_config.predict(**inputs))
- with open(os.path.join('model_data', "{}.pkl".format(model_name))) as f:
- parameters = cPickle.load(f)
- model.set_param_values(parameters)
-
- if 'destination' in outputs:
- dest_outfile = open(os.path.join('output', 'test-dest-output-%s.csv' % model_name), 'w')
- dest_outcsv = csv.writer(dest_outfile)
- dest_outcsv.writerow(["TRIP_ID", "LATITUDE", "LONGITUDE"])
- if 'duration' in outputs:
- time_outfile = open(os.path.join('output', 'test-time-output-%s.csv' % model_name), 'w')
- time_outcsv = csv.writer(time_outfile)
- time_outcsv.writerow(["TRIP_ID", "TRAVEL_TIME"])
-
- function = model.get_theano_function()
- for d in test_stream.get_epoch_iterator(as_dict=True):
- input_values = [d[k.name] for k in model.inputs]
- output_values = function(*input_values)
- if 'destination' in outputs:
- destination = output_values[outputs.index('destination')]
- dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]])
- if 'duration' in outputs:
- duration = output_values[outputs.index('duration')]
- time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))])
-
- if 'destination' in outputs:
- dest_outfile.close()
- if 'duration' in outputs:
- time_outfile.close()
diff --git a/train.py b/train.py
index d40cb88..6d3f37b 100755
--- a/train.py
+++ b/train.py
@@ -1,6 +1,5 @@
#!/usr/bin/env python2
-import cPickle
import importlib
import logging
import operator
@@ -12,7 +11,7 @@ from theano import tensor
from blocks import roles
from blocks.algorithms import AdaDelta, CompositeRule, GradientDescent, RemoveNotFinite, StepRule, Momentum
-from blocks.extensions import Printing, FinishAfter, SimpleExtension
+from blocks.extensions import Printing, FinishAfter
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
import blocks
blocks.config.default_seed = 123
@@ -28,66 +27,11 @@ from blocks.graph import ComputationGraph, apply_dropout, apply_noise
from blocks.main_loop import MainLoop
from blocks.model import Model
+from ext_saveload import SaveLoadParams
+from ext_test import RunOnTest
logger = logging.getLogger(__name__)
-
-class ElementwiseRemoveNotFinite(StepRule):
- """A step rule that replaces non-finite coefficients by zeros.
-
- Replaces non-finite elements (such as ``inf`` or ``NaN``) in a step
- (the parameter update of a single shared variable)
- with a scaled version of the parameters being updated instead.
-
- Parameters
- ----------
- scaler : float, optional
- The scaling applied to the parameter in case the step contains
- non-finite elements. Defaults to 0.1.
-
- Notes
- -----
- This trick was originally used in the GroundHog_ framework.
-
- .. _GroundHog: https://github.com/lisa-groundhog/GroundHog
-
- """
- def __init__(self, scaler=0.1):
- self.scaler = scaler
-
- def compute_step(self, param, previous_step):
- not_finite = tensor.isnan(previous_step) + tensor.isinf(previous_step)
- step = tensor.switch(not_finite, self.scaler * param, previous_step)
-
- return step, []
-
-class SaveLoadParams(SimpleExtension):
- def __init__(self, path, model, **kwargs):
- super(SaveLoadParams, self).__init__(**kwargs)
-
- self.path = path
- self.model = model
-
- def do_save(self):
- with open(self.path, 'w') as f:
- logger.info('Saving parameters to %s...'%self.path)
- cPickle.dump(self.model.get_param_values(), f, protocol=cPickle.HIGHEST_PROTOCOL)
- logger.info('Done saving.')
-
- def do_load(self):
- try:
- with open(self.path, 'r') as f:
- logger.info('Loading parameters from %s...'%self.path)
- self.model.set_param_values(cPickle.load(f))
- except IOError:
- pass
-
- def do(self, which_callback, *args):
- if which_callback == 'before_training':
- self.do_load()
- else:
- self.do_save()
-
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, 'Usage: %s config' % sys.argv[0]
@@ -143,7 +87,7 @@ if __name__ == "__main__":
algorithm = GradientDescent(
cost=cost,
step_rule=CompositeRule([
- ElementwiseRemoveNotFinite(),
+ RemoveNotFinite(),
step_rule
]),
params=params)
@@ -166,6 +110,11 @@ if __name__ == "__main__":
after_epoch=True, # after epoch -> save params
after_training=True, # after training -> save params
),
+
+ RunOnTest(model_name,
+ model,
+ stream,
+ every_n_batches=1000),
]
if use_plot: