aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-06-12 02:45:14 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-06-12 02:45:14 -0400
commit557d0fa74de74b8dbd8618a972725a7a9926e452 (patch)
treea3b63399c9353fc5e29ac793b754fa373a00b415
parentdb0e57fc2a351cedef3b1270bf6047e9cae9fa9d (diff)
downloadtaxi-557d0fa74de74b8dbd8618a972725a7a9926e452.tar.gz
taxi-557d0fa74de74b8dbd8618a972725a7a9926e452.zip
Fix RNN validation
-rw-r--r--model/rnn.py27
-rwxr-xr-xtrain.py12
2 files changed, 31 insertions, 8 deletions
diff --git a/model/rnn.py b/model/rnn.py
index af35414..be17a95 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -93,8 +93,8 @@ class Model(Initializable):
def predict_inputs(self):
return self.inputs
- @application(outputs=['cost'])
- def cost(self, latitude, longitude, latitude_mask, **kwargs):
+ @application(outputs=['cost_matrix'])
+ def cost_matrix(self, latitude, longitude, latitude_mask, **kwargs):
latitude = latitude.T
longitude = longitude.T
latitude_mask = latitude_mask.T
@@ -106,13 +106,30 @@ class Model(Initializable):
axis=2)
target = target.repeat(latitude.shape[0], axis=0)
ce = error.erdist(target.reshape((-1, 2)), res.reshape((-1, 2)))
- ce *= latitude_mask.flatten()
- return ce.sum() / latitude_mask.sum()
+ ce = ce.reshape(latitude.shape)
+ return ce * latitude_mask
+
+ @cost_matrix.property('inputs')
+ def cost_matrix_inputs(self):
+ return self.inputs + ['destination_latitude', 'destination_longitude']
+
+ @application(outputs=['cost'])
+ def cost(self, latitude_mask, **kwargs):
+ return self.cost_matrix(latitude_mask=latitude_mask, **kwargs).sum() / latitude_mask.sum()
@cost.property('inputs')
def cost_inputs(self):
return self.inputs + ['destination_latitude', 'destination_longitude']
+ @application(outputs=['cost'])
+ def valid_cost(self, **kwargs):
+ # Only works when batch_size is 1.
+ return self.cost_matrix(**kwargs)[-1,0]
+
+ @valid_cost.property('inputs')
+ def valid_cost_inputs(self):
+ return self.inputs + ['destination_latitude', 'destination_longitude']
+
class Stream(object):
def __init__(self, config):
@@ -141,7 +158,7 @@ class Stream(object):
stream = transformers.add_destination(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
- stream = Batch(stream, iteration_scheme=ConstantScheme(1000))
+ stream = Batch(stream, iteration_scheme=ConstantScheme(1))
stream = Padding(stream, mask_sources=['latitude', 'longitude'])
stream = transformers.Select(stream, req_vars)
return stream
diff --git a/train.py b/train.py
index 876fcba..94c00d2 100755
--- a/train.py
+++ b/train.py
@@ -100,6 +100,12 @@ if __name__ == "__main__":
cg = ComputationGraph(cost)
monitored = set([cost] + VariableFilter(roles=[roles.COST])(cg.variables))
+ valid_monitored = monitored
+ if hasattr(model, 'valid_cost'):
+ valid_cost = model.valid_cost(**inputs)
+ valid_cg = ComputationGraph(valid_cost)
+ valid_monitored = set([valid_cost] + VariableFilter(roles=[roles.COST])(valid_cg.variables))
+
if hasattr(config, 'dropout') and config.dropout < 1.0:
cg = apply_dropout(cg, config.dropout_inputs(cg), config.dropout)
if hasattr(config, 'noise') and config.noise > 0.0:
@@ -124,7 +130,7 @@ if __name__ == "__main__":
]),
params=params)
- plot_vars = [['valid_' + x.name for x in monitored]]
+ plot_vars = [['valid_' + x.name for x in valid_monitored]]
logger.info('Plotted variables: %s' % str(plot_vars))
dump_path = os.path.join('model_data', model_name)
@@ -136,13 +142,13 @@ if __name__ == "__main__":
dump_ext.manager = CustomDumpManager(dump_path)
extensions=[TrainingDataMonitoring(monitored, prefix='train', every_n_batches=1000),
- DataStreamMonitoring(monitored, valid_stream,
+ DataStreamMonitoring(valid_monitored, valid_stream,
prefix='valid',
every_n_batches=1000),
Printing(every_n_batches=1000),
Plot(model_name, channels=plot_vars, every_n_batches=500),
load_dump_ext,
- dump_ext
+ dump_ext
]
main_loop = MainLoop(