open Ast
open Ast_util
open Formula
open Typing
open Cmdline
open Util
open Num_domain
open Enum_domain
open Varenv
(* Functor building the analyser from an enumerated-variable domain ED and a
   numerical domain ND. Only do_prog is exported by the signature below. *)
module I (ED : ENUM_ENVIRONMENT_DOMAIN) (ND : NUMERICAL_ENVIRONMENT_DOMAIN)
: sig
(* Entry point: analyse a rooted program under the given command-line options. *)
val do_prog : cmdline_opt -> rooted_prog -> unit
end = struct
(* An abstract value pairs an enum-domain component with a numerical-domain one. *)
type abs_v = ED.t * ND.t
(*
Abstract analysis based on dynamic partitioning of the state space.
Idea: use some of the conditions appearing in the text of the program as
disjunctions. We don't want to consider them all at once in the first
place because it would be way too costly; instead we try to dynamically
partition the system. But we haven't got a very good heuristic for that,
so it doesn't work very well.
*)
(* One location (node) of the partition maintained by the analysis. *)
type location = {
(* Key of this location in the loc table of the environment; printed as q<id>. *)
id : int;
(* 0 for the two initial locations; a split creates children with depth + 1. *)
depth : int;
(* Abstract value defining the location: values entering it are met with def. *)
mutable def : abs_v;
(* True for locations seeded by chaotic_iter as starting points. *)
is_init : bool;
(* Program formula specialised (simplified) for this location. *)
mutable f : bool_expr;
(* Constraint list derived from f with conslist_of_f; updated whenever f is. *)
mutable cl : conslist;
(* For chaotic iteration fixpoint *)
(* Number of times v was enlarged by an incoming transition in the current
   chaotic iteration; compared with widen_delay to switch from join to widen. *)
mutable in_c : int;
(* Abstract value currently computed for this location. *)
mutable v : abs_v;
(* Ids of successor locations found by chaotic_iter. *)
mutable out_t : int list;
(* Ids of predecessor locations found by chaotic_iter. *)
mutable in_t : int list;
(* Ids of guarantees whose formula applied to v yields bottom. *)
mutable verif_g : id list;
(* Ids of guarantees whose formula applied to v does not yield bottom. *)
mutable violate_g : id list;
}
(* Global state of one analysis run. *)
type env = {
(* Program under analysis (after the variable replacements done in init_env). *)
rp : rooted_prog;
(* Command-line options (widen_delay, max_dp_width, max_dp_depth, verbose_ci, ...). *)
opt : cmdline_opt;
(* Variable environment built by mk_varenv. *)
ve : varenv;
(* program expressions *)
(* Complete simplified formula of the program. *)
f : bool_expr;
(* Guarantees as (guarantee id, formula, id) triples; do_prog uses the third
   component as the list of variables of interest. *)
guarantees : (id * bool_expr * id) list;
(* data *)
(* Table of live locations, indexed by location id. *)
loc : (int, location) Hashtbl.t;
(* Last location id handed out; incremented before each new allocation. *)
counter : int ref;
}
(* **************************
ABSTRACT VALUES
************************** *)
(*
top : env -> abs_v
bottom : env -> abs_v
*)
(* Greatest abstract value: both components are top over the enumerated and
   numerical variables recorded in the environment's varenv. *)
let top e =
  let vars = e.ve in
  (ED.top vars.evars, ND.top vars.nvars)
(* Canonical bottom value: the enum component is left at top, and the
   numerical component is bottom (which is enough for is_bot to hold). *)
let bottom e =
  let vars = e.ve in
  (ED.top vars.evars, ND.bottom vars.nvars)
(* An abstract value is bottom as soon as either of its components is. *)
let is_bot (enum, num) =
  if ED.is_bot enum then true else ND.is_bot num
(* Pretty-printer for abstract values; any bottom value is shown as a single
   bottom symbol, otherwise both components are printed as a pair. *)
let print_v fmt ((enum, num) as v) =
  match is_bot v with
  | true -> Format.fprintf fmt "⊥"
  | false ->
    Format.fprintf fmt "@[<hov 1>(%a,@ %a)@]" ED.print enum ND.print num
(*
join : abs_v -> abs_v -> abs_v
widen : abs_v -> abs_v -> abs_v
meet : abs_v -> abs_v -> abs_v
*)
(* Least upper bound, component-wise. A bottom argument is absorbed: the other
   argument is returned unchanged. *)
let join ((ea, na) as a) ((eb, nb) as b) =
  if is_bot a then b
  else if is_bot b then a
  else (ED.join ea eb, ND.join na nb)
(* Widening: the enum component is joined, only the numerical component is
   widened. A bottom argument is absorbed like in join. *)
let widen ((ea, na) as a) ((eb, nb) as b) =
  if is_bot a then b
  else if is_bot b then a
  else (ED.join ea eb, ND.widen na nb)
(* Greatest lower bound, component-wise. If either argument is bottom, or if a
   domain raises Bot while meeting, a bottom value is returned (enum component
   at top, numerical component at bottom). *)
let meet (e1, n1) (e2, n2) =
  let bot_of e n = (ED.vtop e, ND.vbottom n) in
  if is_bot (e1, n1) then bot_of e1 n1
  else if is_bot (e2, n2) then bot_of e2 n2
  else
    try (ED.meet e1 e2, ND.meet n1 n2)
    with Bot -> bot_of e1 n1
(*
eq_v : abs_v -> abs_v -> bool
subset_v : abs_v -> abs_v -> bool
*)
(* Equality of abstract values: all bottom values are equal to each other;
   otherwise both components must be equal in their own domain. *)
let eq_v (a, b) (c, d) =
  if is_bot (a, b) && is_bot (c, d) then true
  else ED.eq a c && ND.eq b d
(* Inclusion test: bottom is included in everything, nothing non-bottom is
   included in bottom, and otherwise inclusion is checked component-wise. *)
let subset_v (a, b) (c, d) =
  if is_bot (a, b) then true
  else if is_bot (c, d) then false
  else ED.subset a c && ND.subset b d
(*
apply_cl : abs_v -> conslist -> abs_v
*)
(* [apply_cl v cl] restricts the abstract value [v] by the constraint list
   [cl] = (enum constraints, numerical constraints, remainder).
   - CLTrue: apply both constraint lists; a Bot exception gives bottom.
   - CLFalse: bottom.
   - CLAnd: apply the first conjunct, then the second on the result.
   - CLOr: apply each disjunct (with the common constraints prepended) to the
     original value and join the two results. *)
let rec apply_cl (enum, num) (ec, nc, r) =
  match r with
  | CLFalse -> (ED.vtop enum, ND.vbottom num)
  | CLTrue ->
    (try (ED.apply_cl enum ec, ND.apply_cl num nc)
     with Bot -> (ED.vtop enum, ND.vbottom num))
  | CLAnd (a, b) ->
    let after_a = apply_cl (enum, num) (ec, nc, a) in
    (* NOTE(review): the numerical constraints nc are applied again for the
       second conjunct while the enum constraints are not; apply_cl_all_cases
       passes empty lists for both. Confirm this asymmetry is intended. *)
    apply_cl after_a ([], nc, b)
  | CLOr ((eca, nca, ra), (ecb, ncb, rb)) ->
    let left = apply_cl (enum, num) (ec @ eca, nc @ nca, ra) in
    let right = apply_cl (enum, num) (ec @ ecb, nc @ ncb, rb) in
    join left right
(*
apply_cl_all_cases : abs_v -> conslist -> abs_v list
*)
(* Like apply_cl, but disjunctions are not joined: the result is the list of
   non-bottom abstract values obtained for each disjunctive case. For CLOr,
   cases of the left branch that are eq_v to a case of the right branch are
   dropped, and the right-branch cases come first in the result. *)
let rec apply_cl_all_cases v (ec, nc, r) =
  match r with
  | CLFalse -> []
  | CLTrue ->
    let (enum, num) = v in
    let restricted =
      try (ED.apply_cl enum ec, ND.apply_cl num nc)
      with Bot -> (ED.vtop enum, ND.vbottom num)
    in
    if is_bot restricted then [] else [restricted]
  | CLAnd (a, b) ->
    apply_cl_all_cases v (ec, nc, a)
    |> List.map (fun c -> apply_cl_all_cases c ([], [], b))
    |> List.flatten
  | CLOr ((eca, nca, ra), (ecb, ncb, rb)) ->
    let la = apply_cl_all_cases v (ec @ eca, nc @ nca, ra) in
    let lb = apply_cl_all_cases v (ec @ ecb, nc @ ncb, rb) in
    let already_in_lb a = List.exists (fun b -> eq_v a b) lb in
    lb @ List.filter (fun a -> not (already_in_lb a)) la
(* ***************************
INTERPRET
*************************** *)
(*
init_env : cmdline_opt -> rooted_prog -> env
*)
(* [init_env opt rp] builds the analysis environment for program [rp]:
   - computes and simplifies the complete formula of the program;
   - merges enum variables that a root-level fact states to be equal;
   - computes the guarantees (rewritten with the same replacements);
   - builds the variable environment;
   - creates the two initial locations q0 (L/must_reset = tt, initial) and
     q1 (L/must_reset <> tt).
   Progress information is printed on stdout. *)
let init_env opt rp =
let f = Transform.f_of_prog_incl_init rp false in
let f = simplify_k (get_root_true f) f in
Format.printf "Complete formula:@.%a@.@." Formula_printer.print_expr f;
let facts = get_root_true f in
(* Fold over the root-level facts, accumulating the rewritten formula, the
   program with replaced variables removed, and the replacement list. *)
let f, rp, repls = List.fold_left
(fun (f, (rp : rooted_prog), repls) eq ->
match eq with
(* Only equalities between two enum identifiers whose names do not start with
   the letter L are considered (presumably to skip the L-prefixed copies,
   which are handled together with their base variable below -- TODO confirm). *)
| BEnumCons(E_EQ, a, EIdent b)
when a.[0] <> 'L' && b.[0] <> 'L' ->
(* Resolve both names through the replacements already decided. *)
let a = try List.assoc a repls with Not_found -> a in
let b = try List.assoc b repls with Not_found -> b in
if a = b then
f, rp, repls
else begin
(* Keep the shorter name (a on ties), replace the other one. *)
let keep, repl =
if String.length a <= String.length b
then a, b
else b, a
in
Format.printf "Replacing %s with %s@." repl keep;
(* The L-prefixed counterpart of the variable is replaced as well. *)
let f = formula_replace_evars [repl, keep; "L"^repl, "L"^keep] f in
let rp =
{ rp with all_vars =
List.filter (fun (_, id, _) -> id <> repl) rp.all_vars } in
(* Earlier replacements that pointed to repl now point to keep. *)
let repls = [repl, keep; "L"^repl, "L"^keep]@
(List.map (fun (r, k) -> r, if k = repl then keep else k) repls) in
f, rp, repls
end
| _ -> f, rp, repls)
(f, rp, []) facts in
let f = simplify_k (get_root_true f) f in
Format.printf "Complete formula after simpl:@.%a@.@."
Formula_printer.print_expr f;
let guarantees = Transform.guarantees_of_prog rp in
(* Apply the same variable replacements to the guarantee formulas. *)
let guarantees = List.map
(fun (id, f, v) -> id, formula_replace_evars repls f, v)
guarantees in
Format.printf "Guarantees:@.";
List.iter (fun (id, f, _) ->
Format.printf " %s: %a@." id Formula_printer.print_expr f)
guarantees;
Format.printf "@.";
let ve = mk_varenv rp f (conslist_of_f f) in
(* Ids 0 and 1 are used just below; counter starts at 2 and is incremented
   before each later allocation. *)
let env = {
rp; opt; ve; f; guarantees;
loc = Hashtbl.create 2; counter = ref 2; } in
(* add initial disjunction : L/must_reset = tt, L/must_reset ≠ tt *)
let rstc = BEnumCons(E_EQ, "L/must_reset", EItem bool_true) in
let rstf = simplify_k [rstc] f in
let rstf = simplify_k (get_root_true rstf) rstf in
let nrstc = BEnumCons(E_NE, "L/must_reset", EItem bool_true) in
let nrstf = simplify_k [nrstc] f in
let nrstf = simplify_k (get_root_true nrstf) nrstf in
(* q0: the location where L/must_reset = tt; it is the only initial one. *)
Hashtbl.add env.loc 0
{
id = 0;
depth = 0;
def = apply_cl (top env) (conslist_of_f rstc);
is_init = true;
f = rstf;
cl = conslist_of_f rstf;
in_c = 0;
v = bottom env;
out_t = [];
in_t = [];
verif_g = [];
violate_g = [];
};
(* q1: the location where L/must_reset differs from tt. *)
Hashtbl.add env.loc 1
{
id = 1;
depth = 0;
def = apply_cl (top env) (conslist_of_f nrstc);
is_init = false;
f = nrstf;
cl = conslist_of_f nrstf;
in_c = 0;
v = bottom env;
out_t = [];
in_t = [];
verif_g = [];
violate_g = [];
};
env
(*
ternary_conds : bool_expr -> (bool_expr * bool_expr) list
Each pair is (condition of a ternary, the whole ternary expression).
*)
(* Collects the ternary expressions reachable from the root through
   conjunctions only. Each one is returned as a pair (condition, whole
   ternary expression); anything else contributes nothing. *)
let rec ternary_conds expr =
  match expr with
  | BTernary (c, _, _) -> [ (c, expr) ]
  | BAnd (l, r) -> List.append (ternary_conds l) (ternary_conds r)
  | _ -> []
(*
pass_cycle : varenv -> abs_v -> abs_v (callers pass e.ve)
unpass_cycle : env -> abs_v -> abs_v
(set_target_case : env -> edd_v -> bool_expr -> edd_v and
 cycle : env -> edd_v -> conslist -> edd_v were also listed here,
 but no such definitions follow in this module.)
*)
(* [pass_cycle env v] moves the abstract value [v] across one cycle boundary.
   Despite its name, [env] is the variable environment (callers pass e.ve):
   its cycle field is read directly.
   Each triple (a, b, t) of env.cycle becomes an assignment pair (a, b) in the
   domain selected by the type t, then every variable listed in env.forget is
   forgotten in the matching domain. *)
let pass_cycle env (enum, num) =
  (* Sort one cycle triple into the enum or numerical assignment list. *)
  let add_assign (ae, an) (dst, src, ty) =
    match ty with
    | TEnum _ -> ((dst, src) :: ae, an)
    | TInt | TReal -> (ae, (dst, NIdent src) :: an)
  in
  (* Sort one variable to forget into the enum or numerical list. *)
  let add_forget (ef, nf) (var, ty) =
    match ty with
    | TEnum _ -> (var :: ef, nf)
    | TReal | TInt -> (ef, var :: nf)
  in
  let assign_e, assign_n = List.fold_left add_assign ([], []) env.cycle in
  let enum = ED.assign enum assign_e in
  let num = ND.assign num assign_n in
  let forget_e, forget_n = List.fold_left add_forget ([], []) env.forget in
  (ED.forgetvars enum forget_e, List.fold_left ND.forgetvar num forget_n)
(* [unpass_cycle env v] is the reverse direction of pass_cycle. Here [env] is
   the full analysis environment: each triple (a, b, t) of env.ve.cycle
   becomes the swapped assignment pair (b, a), then the variables of
   env.ve.forget_inv are forgotten in the matching domain. *)
let unpass_cycle env (enum, num) =
  (* Sort one cycle triple, reversed, into the enum or numerical list. *)
  let add_assign (ae, an) (dst, src, ty) =
    match ty with
    | TEnum _ -> ((src, dst) :: ae, an)
    | TInt | TReal -> (ae, (src, NIdent dst) :: an)
  in
  (* Sort one variable to forget into the enum or numerical list. *)
  let add_forget (ef, nf) (var, ty) =
    match ty with
    | TEnum _ -> (var :: ef, nf)
    | TReal | TInt -> (ef, var :: nf)
  in
  let assign_e, assign_n = List.fold_left add_assign ([], []) env.ve.cycle in
  let enum = ED.assign enum assign_e in
  let num = ND.assign num assign_n in
  let forget_e, forget_n =
    List.fold_left add_forget ([], []) env.ve.forget_inv
  in
  (ED.forgetvars enum forget_e, List.fold_left ND.forgetvar num forget_n)
(*
print_locs_defs : env -> unit
print_locs : env -> unit
*)
(* Prints, for every location of the table, its id and its definition. *)
let print_locs_defs e =
  let show id loc =
    Format.printf "q%d: @[<v 2>%a@]@." id print_v loc.def
  in
  Hashtbl.iter show e.loc
(* Prints every location of the table: id, depth, definition (D), current
   value (V) and the list of successor locations. *)
let print_locs e =
  let pp_target fmt i = Format.fprintf fmt "q%d" i in
  let show id loc =
    Format.printf "@.";
    Format.printf "q%d (depth = %d):@. D: @[<v 2>%a@]@."
      id loc.depth print_v loc.def;
    (*Format.printf " F: (%a)@." Formula_printer.print_expr loc.f;*)
    Format.printf " V: %a@." print_v loc.v;
    Format.printf " -> @[<hov>[%a]@]@." (print_list pp_target ", ") loc.out_t
  in
  Hashtbl.iter show e.loc
(* [dump_graphwiz_trans_graph e file] writes the transition graph between the
   locations of [e] to [file] in Graphviz dot syntax.
   Each location becomes a node labelled with its id and the guarantees it
   violates (initial locations get a double circle). Each entry of out_t
   becomes an edge; an edge leading to a location violating strictly more
   guarantees than its source is drawn in dark red with weight 1, the others
   in black with weight 2.
   The channel is closed even if writing fails: the exception (for instance
   Not_found from a dangling out_t entry, or Sys_error) is re-raised after the
   channel has been released. *)
let dump_graphwiz_trans_graph e file =
  let o = open_out file in
  let emit () =
    let fmt = Format.formatter_of_out_channel o in
    Format.fprintf fmt "digraph G{@.";
    Hashtbl.iter
      (fun id loc ->
        if loc.is_init then
          Format.fprintf fmt " q%d [shape=doublecircle, label=\"q%d [%a]\"];@."
            id id (print_list Format.pp_print_string ", ") loc.violate_g
        else
          Format.fprintf fmt " q%d [label=\"q%d [%a]\"];@."
            id id (print_list Format.pp_print_string ", ") loc.violate_g;
        let n1 = List.length loc.violate_g in
        List.iter
          (fun v ->
            let n2 = List.length (Hashtbl.find e.loc v).violate_g in
            let c, w =
              if n2 > n1 then "#770000", 1
              else "black", 2
            in
            Format.fprintf fmt " q%d -> q%d [color = \"%s\", weight = %d];@."
              id v c w)
          loc.out_t)
      e.loc;
    (* The final @. flushes the formatter before the channel is closed. *)
    Format.fprintf fmt "}@."
  in
  (try emit ()
   with exn ->
     (* Do not leak the file descriptor on failure. *)
     close_out_noerr o;
     raise exn);
  close_out o
(*
chaotic_iter : env -> unit
Fills the values of loc[*].v, and updates out_t and in_t
*)
(* [chaotic_iter e] recomputes, with a worklist, the abstract value v of every
   location in e.loc together with the in_t / out_t transition lists.
   Afterwards, locations whose value is still bottom are removed from the
   table, the verif_g / violate_g lists of the remaining locations are
   recomputed, and all locations are printed. Progress goes to stdout. *)
let chaotic_iter e =
(* Worklist of location ids whose value must be (re)propagated. *)
let delta = ref [] in
(* Fill up initial states *)
Hashtbl.iter
(fun q loc ->
loc.out_t <- [];
loc.in_t <- [];
loc.in_c <- 0;
if loc.is_init then begin
loc.v <- apply_cl (top e) loc.cl;
delta := q::!delta
end else
loc.v <- bottom e)
e.loc;
(*print_locs_defs e;*)
(* Iterate *)
let it_counter = ref 0 in
while !delta <> [] do
(* The head of the worklist is processed; it is removed at the end of the
   loop body. *)
let s = List.hd !delta in
let loc = Hashtbl.find e.loc s in
incr it_counter;
Format.printf "@.Iteration %d: q%d@." !it_counter s;
let start = loc.v in
(* One step inside location s: restrict i to the values that can cycle back
   into s, pass the cycle, restrict by the definition and constraints of s,
   and join with the starting value. *)
let f i =
(*Format.printf "I: %a@." print_v i;*)
let i' = meet i (unpass_cycle e loc.def) in
(*Format.printf "I': %a@." print_v i';*)
let j = join start
(apply_cl
(meet (pass_cycle e.ve i') loc.def)
loc.cl) in
(*Format.printf "J: %a@." print_v j;*)
j
in
(* Local fixpoint of f: join during the first widen_delay rounds, widen
   afterwards; stops when the value no longer changes. *)
let rec iter n i =
let fi = f i in
let j =
if n < e.opt.widen_delay then
join i fi
else
widen i fi
in
if eq_v i j
then i
else iter (n+1) j
in
let y = iter 0 start in
(* One more application of f on the stabilised value gives the value kept
   for the location; u is that value after passing the cycle. *)
let z = f y in
let u = pass_cycle e.ve z in
if e.opt.verbose_ci then
Format.printf "Fixpoint: %a@. mem fp: %a@." print_v z print_v u;
loc.v <- z;
(* Propagate u to every location t (including s itself): a non-bottom
   result means there is a transition s -> t. *)
Hashtbl.iter
(fun t loc2 ->
let v = meet u loc2.def in
let w = apply_cl v loc2.cl in
(*Format.printf "u: %a@.v: %a@. w: %a@." print_v u print_v v print_v w;*)
if not (is_bot w) then begin
if e.opt.verbose_ci then
Format.printf "%d -> %d with:@. %a@." s t print_v w;
if not (List.mem s loc2.in_t)
then loc2.in_t <- s::loc2.in_t;
if not (List.mem t loc.out_t)
then loc.out_t <- t::loc.out_t;
(* Enlarge the target only if w brings something new; widen once the
   target has already been enlarged widen_delay times. *)
if not (subset_v w loc2.v) then begin
if loc2.in_c < e.opt.widen_delay then
loc2.v <- join loc2.v w
else
loc2.v <- widen loc2.v w;
loc2.in_c <- loc2.in_c + 1;
if not (List.mem t !delta)
then delta := t::!delta
end
end)
e.loc;
(* s has been processed: drop it from the worklist. *)
delta := List.filter ((<>) s) !delta;
done;
(* remove useless locations *)
(* Ids are collected first because the table is not modified while it is
   being iterated over. *)
let useless = ref [] in
Hashtbl.iter
(fun i loc ->
if is_bot loc.v then begin
Format.printf "Useless location detected: q%d@." i;
useless := i::!useless
end)
e.loc;
List.iter (Hashtbl.remove e.loc) !useless;
(* check which states verify/violate guarantees *)
(* A guarantee is verified in a location when applying its formula to the
   value of the location yields bottom. *)
Hashtbl.iter
(fun _ loc ->
let verif, violate = List.partition
(fun (_, f, _) ->
is_bot (apply_cl loc.v (conslist_of_f f)))
e.guarantees
in
loc.verif_g <- List.map (fun (a, b, c) -> a) verif;
loc.violate_g <- List.map (fun (a, b, c) -> a) violate)
e.loc;
print_locs e;
()
(* [do_prog opt rp] runs the whole analysis on program [rp].
   It builds the environment, then repeats rounds made of a chaotic iteration,
   a Graphviz dump to /tmp/partNNN.dot, and a refinement step that splits one
   location on a ternary condition chosen by a heuristic score. The rounds
   stop when no split is selected or when the number of locations reaches
   opt.max_dp_width. Finally the status of each guarantee and the range of
   each probe variable are printed on stdout. *)
let do_prog opt rp =
let e = init_env opt rp in
let rec iter n =
Format.printf "@.--------------@.Refinement #%d@." n;
chaotic_iter e;
dump_graphwiz_trans_graph e (Format.sprintf "/tmp/part%03d.dot" n);
(* Best split found so far: (score, location id, condition, cases where the
   condition holds, cases where it does not). *)
let qc = ref None in
if Hashtbl.length e.loc < e.opt.max_dp_width then begin
(* put true or false conditions into location definition *)
(* A condition (or its negation) that is incompatible with the value of the
   location is decided there: the other alternative is added to def and used
   to simplify f. Repeated until no such condition is left. *)
Hashtbl.iter
(fun q (loc : location) ->
let rec iter () =
try
let cond, _ = List.find
(fun (c, _) ->
is_bot (apply_cl loc.v (conslist_of_f c))
|| is_bot (apply_cl loc.v (conslist_of_f (BNot c))))
(ternary_conds loc.f)
in
let tr =
if is_bot (apply_cl loc.v (conslist_of_f cond))
then BNot cond
else cond
in
loc.def <- apply_cl loc.def (conslist_of_f tr);
loc.f <- simplify_k [tr] loc.f;
loc.f <- simplify_k (get_root_true loc.f) loc.f;
loc.cl <- conslist_of_f loc.f;
iter()
with Not_found -> ()
in iter ())
e.loc;
(* find splitting condition *)
(* voi: third component of every guarantee (variables of interest). *)
let voi = List.map (fun (a, b, c) -> c) e.guarantees in
Hashtbl.iter
(fun q (loc:location) ->
if loc.depth < e.opt.max_dp_depth then
let cs = ternary_conds loc.f in
List.iter
(fun (c, exprs) ->
(* Enumerate the disjunctive cases of c and of its negation, and number them. *)
let cases_t = apply_cl_all_cases (top e) (conslist_of_f c) in
let cases_f = apply_cl_all_cases (top e) (conslist_of_f (BNot c)) in
let cases = List.mapi (fun i c -> i, c) (cases_t @ cases_f) in
(* Only consider c if at least two cases intersect the value of the location. *)
if
List.length
(List.filter
(fun (_, case) -> not (is_bot (meet loc.v case)))
cases)
>= 2
then
(* calculate which transitions qi -> q stay or are destroyed (approximation) *)
(* Triples (source id, case index, transition still possible). *)
let in_tc =
List.flatten @@ List.map
(fun qi ->
let loci = Hashtbl.find e.loc qi in
let v = apply_cl
(meet (pass_cycle e.ve loci.v) loc.def)
loc.cl in
List.map
(fun (ci, case) -> qi, ci, not (is_bot (meet v case)))
cases)
loc.in_t
in
(* calculate which transitions q -> qo stay or are destroyed (approximation) *)
(* Triples (target id, case index, transition still possible). *)
let out_tc =
List.flatten @@ List.map
(fun (ci, case) ->
let v = meet loc.v case in
List.map
(fun qo ->
let loco = Hashtbl.find e.loc qo in
let w = apply_cl
(meet (pass_cycle e.ve v) loco.def)
loco.cl
in qo, ci, not (is_bot w))
loc.out_t)
cases
in
(* calculate which cases have a good number of disappearing transitions *)
let fa =
(* Per case: a = destroyed incoming transitions, b = destroyed outgoing ones. *)
let cs_sc =
List.map
(fun (ci, case) ->
let a =
List.length
(List.filter (fun (qi, c, a) -> not a && c = ci) in_tc)
in
let b =
List.length
(List.filter (fun (qo, c, a) -> not a && c = ci) out_tc)
in
a + b + a * b)
cases
in
let a = List.fold_left max 0 cs_sc in
(* b counts the predecessors qi that, after the split, can no longer reach
   every former successor of this location through the surviving cases. *)
let b = if a = 0 then 0 else
List.length @@ List.filter
(fun qi ->
let qos = List.flatten @@ List.map
(fun (cid, c) ->
if List.exists (fun (qi0, c0, a) -> a && qi0 = qi && c0 = cid) in_tc
then
List.map (fun (qo, _, _) -> qo) @@
List.filter (fun (_, c1, a) -> a && cid = c1) out_tc
else [])
cases
in
List.exists (fun qo -> not (List.mem qo qos)) loc.out_t)
loc.in_t
in
5 * a + 17 * b
in
if fa <> 0 then begin
(* calculate which states become inaccessible *)
let fb =
if List.for_all (fun (_, _, a) -> a) out_tc
then 0
else
(* In the hypothetical graph, case number ci of the split location is
   represented by the node id ci + 1000000. *)
let ff id = (* transition function for new graph *)
if id >= 1000000 then
let case = id - 1000000 in
List.map (fun (qo, _, _) -> qo)
(List.filter (fun (_, c, a) -> c = case && a) out_tc)
else
let out_t = (Hashtbl.find e.loc id).out_t in
if List.mem loc.id out_t then
(List.map (fun (_, c, _) -> c + 1000000)
(List.filter (fun (qi, _, a) -> qi = id && a) in_tc))
@ (List.filter ((<>) id) out_t)
else out_t
in
(* Graph traversal from the initial nodes; memo holds the reached nodes. *)
let memo = Hashtbl.create 12 in
let rec do_x id =
if not (Hashtbl.mem memo id) then begin
Hashtbl.add memo id ();
List.iter do_x (ff id)
end
in
Hashtbl.iter (fun i loc2 -> if loc2.is_init && i <> loc.id then do_x i) e.loc;
if loc.is_init then List.iter (fun (ci, _) -> do_x (ci+1000000)) cases;
let disappear_count = (Hashtbl.length e.loc + List.length cases) - (Hashtbl.length memo) in
21 * disappear_count
in
(* calculate in/out count, weighted by changing guarantees *)
let fc =
1 * (2 * List.length loc.out_t + List.length loc.in_t)
in
(* calculate number of VOI (variables of interest) that are affected *)
let fd =
let vlist = refd_evars_of_f exprs in
3 * List.length
(List.filter (fun v -> List.mem v vlist) voi) in
(* give score to split *)
(* fa is known to be non-zero here, so the first branch is never taken. *)
let score =
if fa = 0 then 0 else
fa + fb + fc + fd
in
Format.printf " %5d + %5d + %5d + %5d = %5d (q%d)@." fa fb fc fd score loc.id;
(* Keep the candidate with the highest score; on ties the latest one wins. *)
if score > 0 &&
match !qc with
| None -> true
| Some (s, _, _, _, _) -> score >= s
then
qc := Some(score, q, c, cases_t, cases_f)
end)
cs
)
e.loc;
match !qc with
| None ->
Format.printf "@.Found no more possible refinement.@.@."
| Some (score, q, c, cases_t, cases_f) ->
Format.printf "@.Refine q%d : @[<v 2>[ %a ]@]@." q
(print_list print_v ", ") (cases_t@cases_f);
(* Replace location q by one new location per case that intersects its value. *)
let loc = Hashtbl.find e.loc q in
Hashtbl.remove e.loc loc.id;
let handle_case cc case =
if not (is_bot (meet loc.v case)) then
let ff = simplify_k [cc] loc.f in
let ff = simplify_k (get_root_true ff) ff in
let loc2 =
{ loc with
id = (incr e.counter; !(e.counter));
depth = loc.depth + 1;
def = meet loc.def case;
f = ff;
cl = conslist_of_f ff } in
Hashtbl.add e.loc loc2.id loc2
in
List.iter (handle_case c) cases_t;
List.iter (handle_case (BNot c)) cases_f;
iter (n+1)
end
in iter 0;
(* Check guarantees *)
(* A guarantee is reported as violated when its id appears in the violate_g
   list of at least one remaining location. *)
let check_guarantee (id, f, _) =
Format.printf "@[<v 4>";
let cl = Formula.conslist_of_f f in
Format.printf "%s:@ %a ⇒ ⊥ @ "
id Formula_printer.print_conslist cl;
let violate = ref [] in
Hashtbl.iter
(fun lid loc ->
if List.mem id loc.violate_g then
violate := lid::!violate)
e.loc;
if !violate = [] then
Format.printf "OK"
else
Format.printf "VIOLATED in @[<hov 2>[ %a ]@]"
(print_list (fun fmt i -> Format.fprintf fmt "q%d" i) ", ") !violate;
Format.printf "@]@ ";
in
if e.guarantees <> [] then begin
Format.printf "Guarantee @[<v 0>";
List.iter check_guarantee e.guarantees;
Format.printf "@]@."
end;
(* Examine probes *)
(* The first component of an all_vars entry tells whether the variable is a
   probe; probes are reported on the join of the values of all locations. *)
if List.exists (fun (p, _, _) -> p) e.ve.all_vars then begin
let final =
Hashtbl.fold
(fun _ loc v -> join v loc.v)
e.loc (bottom e) in
Format.printf "Probes: @[<v 0>";
List.iter (fun (p, id, ty) ->
if p then match ty with
| TInt | TReal ->
Format.printf "%a ∊ %a@ " Formula_printer.print_id id
ND.print_itv (ND.project (snd final) id)
| TEnum _ -> Format.printf "%a : enum variable@ "
Formula_printer.print_id id)
e.ve.all_vars;
Format.printf "@]@."
end
end