summaryrefslogtreecommitdiff
path: root/abstract/abs_interp_dynpart.ml
diff options
context:
space:
mode:
Diffstat (limited to 'abstract/abs_interp_dynpart.ml')
-rw-r--r--abstract/abs_interp_dynpart.ml176
1 files changed, 124 insertions, 52 deletions
diff --git a/abstract/abs_interp_dynpart.ml b/abstract/abs_interp_dynpart.ml
index 006854e..bc0755b 100644
--- a/abstract/abs_interp_dynpart.ml
+++ b/abstract/abs_interp_dynpart.ml
@@ -48,7 +48,7 @@ end = struct
(* program expressions *)
f : bool_expr;
- guarantees : (id * bool_expr) list;
+ guarantees : (id * bool_expr * id) list;
(* data *)
loc : (int, location) Hashtbl.t;
@@ -168,7 +168,7 @@ end = struct
Format.printf "Complete formula:@.%a@.@." Formula_printer.print_expr f;
let facts = get_root_true f in
- let f, rp, _ = List.fold_left
+ let f, rp, repls = List.fold_left
(fun (f, (rp : rooted_prog), repls) eq ->
match eq with
| BEnumCons(E_EQ, a, EIdent b)
@@ -190,7 +190,7 @@ end = struct
let rp =
{ rp with all_vars =
List.filter (fun (_, id, _) -> id <> repl) rp.all_vars } in
- let repls = (repl, keep)::
+ let repls = [repl, keep; "L"^repl, "L"^keep]@
(List.map (fun (r, k) -> r, if k = repl then keep else k) repls) in
f, rp, repls
end
@@ -203,8 +203,11 @@ end = struct
Formula_printer.print_expr f;
let guarantees = Transform.guarantees_of_prog rp in
+ let guarantees = List.map
+ (fun (id, f, v) -> id, formula_replace_evars repls f, v)
+ guarantees in
Format.printf "Guarantees:@.";
- List.iter (fun (id, f) ->
+ List.iter (fun (id, f, _) ->
Format.printf " %s: %a@." id Formula_printer.print_expr f)
guarantees;
Format.printf "@.";
@@ -266,7 +269,7 @@ end = struct
*)
let rec ternary_conds = function
| BAnd(a, b) -> ternary_conds a @ ternary_conds b
- | BTernary(c, a, b) -> [c]
+ | BTernary(c, a, b) as x -> [c, x]
| _ -> []
(*
@@ -328,7 +331,7 @@ end = struct
(fun id loc ->
Format.printf "@.";
Format.printf "q%d (depth = %d):@. D: @[<v 2>%a@]@." id loc.depth print_v loc.def;
- Format.printf " F: (%a)@." Formula_printer.print_expr loc.f;
+ (*Format.printf " F: (%a)@." Formula_printer.print_expr loc.f;*)
Format.printf " V: %a@." print_v loc.v;
Format.printf " -> @[<hov>[%a]@]@."
(print_list (fun fmt i -> Format.fprintf fmt "q%d" i) ", ") loc.out_t;
@@ -338,20 +341,30 @@ end = struct
let dump_graphwiz_trans_graph e file =
let o = open_out file in
let fmt = Format.formatter_of_out_channel o in
- Format.fprintf fmt "digraph G{@[<v 4>@ ";
+ Format.fprintf fmt "digraph G{@.";
Hashtbl.iter
(fun id loc ->
if loc.is_init then
- Format.fprintf fmt "q%d [shape=doublecircle, label=\"q%d [%a]\"];@ "
+ Format.fprintf fmt " q%d [shape=doublecircle, label=\"q%d [%a]\"];@."
id id (print_list Format.pp_print_string ", ") loc.violate_g
else
- Format.fprintf fmt "q%d [label=\"q%d [%a]\"];@ "
+ Format.fprintf fmt " q%d [label=\"q%d [%a]\"];@."
id id (print_list Format.pp_print_string ", ") loc.violate_g;
- List.iter (fun v -> Format.fprintf fmt "q%d -> q%d;@ " id v) loc.out_t)
+ let n1 = List.length loc.violate_g in
+ List.iter
+ (fun v ->
+ let n2 = List.length (Hashtbl.find e.loc v).violate_g in
+ let c, w =
+ if n2 > n1 then "#770000", 1
+ else "black", 2
+ in
+ Format.fprintf fmt " q%d -> q%d [color = \"%s\", weight = %d];@."
+ id v c w)
+ loc.out_t)
e.loc;
- Format.fprintf fmt "@]@.}@.";
+ Format.fprintf fmt "}@.";
close_out o
@@ -376,7 +389,7 @@ end = struct
loc.v <- bottom e)
e.loc;
- print_locs_defs e;
+ (*print_locs_defs e;*)
(* Iterate *)
let it_counter = ref 0 in
@@ -413,7 +426,7 @@ end = struct
else iter (n+1) j
in
let y = iter 0 start in
- let z = fix eq_v f y in
+ let z = f y in
let u = pass_cycle e.ve z in
if e.opt.verbose_ci then
@@ -463,15 +476,17 @@ end = struct
Hashtbl.iter
(fun _ loc ->
let verif, violate = List.partition
- (fun (_, f) ->
+ (fun (_, f, _) ->
is_bot (apply_cl loc.v (conslist_of_f f)))
e.guarantees
in
- loc.verif_g <- List.map fst verif;
- loc.violate_g <- List.map fst violate)
+ loc.verif_g <- List.map (fun (a, b, c) -> a) verif;
+ loc.violate_g <- List.map (fun (a, b, c) -> a) violate)
e.loc;
- print_locs e
+ print_locs e;
+
+ ()
@@ -491,8 +506,8 @@ end = struct
(fun q (loc : location) ->
let rec iter () =
try
- let cond = List.find
- (fun c ->
+ let cond, _ = List.find
+ (fun (c, _) ->
is_bot (apply_cl loc.v (conslist_of_f c))
|| is_bot (apply_cl loc.v (conslist_of_f (BNot c))))
(ternary_conds loc.f)
@@ -512,56 +527,113 @@ end = struct
e.loc;
(* find splitting condition *)
+ let voi = List.map (fun (a, b, c) -> c) e.guarantees in
+
Hashtbl.iter
(fun q (loc:location) ->
if loc.depth < e.opt.max_dp_depth then
let cs = ternary_conds loc.f in
List.iter
- (fun c ->
+ (fun (c, exprs) ->
let cases_t = apply_cl_all_cases (top e) (conslist_of_f c) in
let cases_f = apply_cl_all_cases (top e) (conslist_of_f (BNot c)) in
- let cases = cases_t @ cases_f in
+ let cases = List.mapi (fun i c -> i, c) (cases_t @ cases_f) in
if
List.length
(List.filter
- (fun case -> not (is_bot (meet loc.v case)))
+ (fun (_, case) -> not (is_bot (meet loc.v case)))
cases)
>= 2
then
- let score =
- let w1 = List.fold_left (+.) 0.
- (List.map
- (fun qi ->
- let loci = Hashtbl.find e.loc qi in
- let v = apply_cl
- (meet (pass_cycle e.ve loci.v) loc.def)
- loc.cl in
- let n = List.length @@ List.filter
- (fun case -> is_bot (meet v case))
- cases
- in if n > 0 then 1. else 0.)
- loc.in_t)
- in
- let w2 = List.fold_left (+.) 0.
- (List.map
- (fun qo ->
- let loco = Hashtbl.find e.loc qo in
- let n = List.length @@ List.filter
- (fun case ->
- let v = meet loc.v case in
+ (* calculate which transitions qi -> q stay or are destroyed (approximation) *)
+ let in_tc =
+ List.flatten @@ List.map
+ (fun qi ->
+ let loci = Hashtbl.find e.loc qi in
+ let v = apply_cl
+ (meet (pass_cycle e.ve loci.v) loc.def)
+ loc.cl in
+ List.map
+ (fun (ci, case) -> qi, ci, not (is_bot (meet v case)))
+ cases)
+ loc.in_t
+ in
+ (* calculate which transitions q -> qo stay or are destroyed (approximation) *)
+ let out_tc =
+ List.flatten @@ List.map
+ (fun (ci, case) ->
+ let v = meet loc.v case in
+ List.map
+ (fun qo ->
+ let loco = Hashtbl.find e.loc qo in
let w = apply_cl
(meet (pass_cycle e.ve v) loco.def)
loco.cl
- in is_bot w)
- cases
- in if n > 0 then 1. else 0.)
- loc.out_t)
+ in qo, ci, not (is_bot w))
+ loc.out_t)
+ cases
+ in
+ (* calculate which cases have a good number of disappearing transitions *)
+ let fa =
+ let cs_sc =
+ List.map
+ (fun (ci, case) ->
+ let a =
+ List.length
+ (List.filter (fun (qi, c, a) -> not a && c = ci) in_tc)
+ in
+ let b =
+ List.length
+ (List.filter (fun (qo, c, a) -> not a && c = ci) out_tc)
+ in
+ a + b + a * b)
+ cases
in
- ((w1 /. 1. (*(float_of_int (List.length loc.in_t))*))
- +. (w2 /. 1. (*(float_of_int (List.length loc.out_t))*)))
- /. (float_of_int (List.length cases))
+ 5 * List.fold_left max 0 cs_sc
+ in
+ (* calculate which states become inaccessible *)
+ let fb =
+ if fa = 0 || List.for_all (fun (_, _, a) -> a) out_tc
+ then 0
+ else
+ let ff id = (* transition function for new graph *)
+ if id >= 1000000 then
+ let case = id - 1000000 in
+ List.map (fun (qo, _, _) -> qo)
+ (List.filter (fun (_, c, a) -> c = case && a) out_tc)
+ else
+ let out_t = (Hashtbl.find e.loc id).out_t in
+ if List.mem loc.id out_t then
+ (List.map (fun (_, c, _) -> c + 1000000)
+ (List.filter (fun (qi, _, a) -> qi = id && a) in_tc))
+ @ (List.filter ((<>) id) out_t)
+ else out_t
+ in
+ let memo = Hashtbl.create 12 in
+ let rec do_x id =
+ if not (Hashtbl.mem memo id) then begin
+ Hashtbl.add memo id ();
+ List.iter do_x (ff id)
+ end
+ in
+ Hashtbl.iter (fun i loc2 -> if loc2.is_init && i <> loc.id then do_x i) e.loc;
+ if loc.is_init then List.iter (fun (ci, _) -> do_x (ci+1000000)) cases;
+ let disappear_count = (Hashtbl.length e.loc + List.length cases) - (Hashtbl.length memo) in
+ 21 * disappear_count
+ in
+ (* calculate number of VOI (variables of interest) that are affected *)
+ let fc =
+ let vlist = refd_evars_of_f exprs in
+ 10 * List.length
+ (List.filter (fun v -> List.mem v vlist) voi) in
+ (* give score to split *)
+ let fd = 2 * List.length loc.out_t + List.length loc.in_t in
+ let score =
+ if fa = 0 then 0 else
+ fa + fb + fc + fd
in
- if
+ Format.printf " %5d + %5d + %5d + %5d = %5d (q%d)@." fa fb fc fd score loc.id;
+ if score > 0 &&
match !qc with
| None -> true
| Some (s, _, _, _, _) -> score >= s
@@ -604,7 +676,7 @@ end = struct
in iter 0;
(* Check guarantees *)
- let check_guarantee (id, f) =
+ let check_guarantee (id, f, _) =
Format.printf "@[<v 4>";
let cl = Formula.conslist_of_f f in
Format.printf "%s:@ %a ⇒ ⊥ @ "