From 557d0fa74de74b8dbd8618a972725a7a9926e452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Fri, 12 Jun 2015 02:45:14 -0400 Subject: Fix RNN validation --- model/rnn.py | 27 ++++++++++++++++++++++----- train.py | 12 +++++++++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/model/rnn.py b/model/rnn.py index af35414..be17a95 100644 --- a/model/rnn.py +++ b/model/rnn.py @@ -93,8 +93,8 @@ class Model(Initializable): def predict_inputs(self): return self.inputs - @application(outputs=['cost']) - def cost(self, latitude, longitude, latitude_mask, **kwargs): + @application(outputs=['cost_matrix']) + def cost_matrix(self, latitude, longitude, latitude_mask, **kwargs): latitude = latitude.T longitude = longitude.T latitude_mask = latitude_mask.T @@ -106,13 +106,30 @@ class Model(Initializable): axis=2) target = target.repeat(latitude.shape[0], axis=0) ce = error.erdist(target.reshape((-1, 2)), res.reshape((-1, 2))) - ce *= latitude_mask.flatten() - return ce.sum() / latitude_mask.sum() + ce = ce.reshape(latitude.shape) + return ce * latitude_mask + + @cost_matrix.property('inputs') + def cost_matrix_inputs(self): + return self.inputs + ['destination_latitude', 'destination_longitude'] + + @application(outputs=['cost']) + def cost(self, latitude_mask, **kwargs): + return self.cost_matrix(latitude_mask=latitude_mask, **kwargs).sum() / latitude_mask.sum() @cost.property('inputs') def cost_inputs(self): return self.inputs + ['destination_latitude', 'destination_longitude'] + @application(outputs=['cost']) + def valid_cost(self, **kwargs): + # Only works when batch_size is 1. + return self.cost_matrix(**kwargs)[-1,0] + + @valid_cost.property('inputs') + def valid_cost_inputs(self): + return self.inputs + ['destination_latitude', 'destination_longitude'] + class Stream(object): def __init__(self, config): @@ -141,7 +158,7 @@ class Stream(object): stream = transformers.add_destination(stream) stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask'))) - stream = Batch(stream, iteration_scheme=ConstantScheme(1000)) + stream = Batch(stream, iteration_scheme=ConstantScheme(1)) stream = Padding(stream, mask_sources=['latitude', 'longitude']) stream = transformers.Select(stream, req_vars) return stream diff --git a/train.py b/train.py index 876fcba..94c00d2 100755 --- a/train.py +++ b/train.py @@ -100,6 +100,12 @@ if __name__ == "__main__": cg = ComputationGraph(cost) monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables)) + valid_monitored = monitored + if hasattr(model, 'valid_cost'): + valid_cost = model.valid_cost(**inputs) + valid_cg = ComputationGraph(valid_cost) + valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables)) + if hasattr(config, 'dropout') and config.dropout < 1.0: cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout) if hasattr(config, 'noise') and config.noise > 0.0: @@ -124,7 +130,7 @@ if __name__ == "__main__": ]), params=params) - plot_vars = [['valid_' + x.name for x in monitored]] + plot_vars = [['valid_' + x.name for x in valid_monitored]] logger.info('Plotted variables: %s' % str(plot_vars)) dump_path = os.path.join('model_data', model_name) @@ -136,13 +142,13 @@ if __name__ == "__main__": dump_ext.manager = CustomDumpManager(dump_path) extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000), - DataStreamMonitoring(monitored, valid_stream, + DataStreamMonitoring(valid_monitored, valid_stream, prefix='valid', every_n_batches=1000), Printing(every_n_batches=1000), Plot(model_name, channels=plot_vars, every_n_batches=500), load_dump_ext, - dump_ext + dump_ext ] main_loop = MainLoop( -- cgit v1.2.3