diff options
Diffstat (limited to 'model/joint_simple_mlp_tgtcls.py')
-rw-r--r-- | model/joint_simple_mlp_tgtcls.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/model/joint_simple_mlp_tgtcls.py b/model/joint_simple_mlp_tgtcls.py index 0a38e06..dd1242e 100644 --- a/model/joint_simple_mlp_tgtcls.py +++ b/model/joint_simple_mlp_tgtcls.py @@ -67,9 +67,13 @@ class Model(object): dest_cost.name = 'dest_cost' dest_hcost = error.hdist(dest_outputs, y_dest).mean() dest_hcost.name = 'dest_hcost' + time_cost = error.rmsle(time_outputs.flatten(), y_time.flatten()) time_cost.name = 'time_cost' - cost = dest_cost + time_cost + time_scost = config.time_cost_factor * time_cost + time_scost.name = 'time_scost' + + cost = dest_cost + time_scost cost.name = 'cost' # Initialization @@ -83,7 +87,7 @@ class Model(object): mlp.initialize() self.cost = cost - self.monitor = [cost, dest_cost, dest_hcost, time_cost] + self.monitor = [cost, dest_cost, dest_hcost, time_cost, time_scost] self.outputs = tensor.concatenate([dest_outputs, time_outputs[:, None]], axis=1) self.outputs.name = 'outputs' self.pred_vars = ['destination_longitude', 'destination_latitude', 'travel_time'] |