open Ast
open Ast_util
open Formula
open Typing
open Cmdline

open Util
open Num_domain
open Enum_domain
open Varenv

(* Abstract interpreter over a partition of the state space into
   "locations". Each location carries its own specialized formula and
   abstract value. Invariants are computed by chaotic iteration
   (chaotic_iter), and the partition is refined (do_prog) by splitting one
   location on the condition of a ternary expression. Abstract values pair
   an enumerated-variable domain [ED] with a numerical domain [ND]. *)
module I (ED : ENUM_ENVIRONMENT_DOMAIN) (ND : NUMERICAL_ENVIRONMENT_DOMAIN) : sig

  val do_prog : cmdline_opt -> rooted_prog -> unit

end = struct

  type abs_v = ED.t * ND.t

  type location = {
    id : int;
    (* number of splits that produced this location (bounded by
       opt.max_dp_depth) *)
    depth : int;

    (* restriction on the states belonging to this location *)
    mutable def : abs_v;
    is_init : bool;

    (* program formula specialized to this location, and its conslist *)
    mutable f : bool_expr;
    mutable cl : conslist;

    (* For chaotic iteration fixpoint *)
    (* number of times [v] was enlarged from an incoming edge; decides
       join vs. widen against opt.widen_delay *)
    mutable in_c : int;
    mutable v : abs_v;
    (* successor / predecessor location ids found by the last iteration *)
    mutable out_t : int list;
    mutable in_t : int list;
    (* ids of guarantees whose negation is / is not refuted by [v] *)
    mutable verif_g : id list;
    mutable violate_g : id list;
  }

  type env = {
    rp : rooted_prog;
    opt : cmdline_opt;
    ve : varenv;

    (* program expressions *)
    f : bool_expr;
    (* (guarantee id, formula checked for emptiness, variable of interest) *)
    guarantees : (id * bool_expr * id) list;

    (* data *)
    loc : (int, location) Hashtbl.t;
    (* last location id handed out *)
    counter : int ref;
  }

  (* ************************** ABSTRACT VALUES ************************** *)

  (* top : env -> abs_v
     bottom : env -> abs_v
     Bottom is represented by a bottom numerical component. *)
  let top e = (ED.top e.ve.evars, ND.top e.ve.nvars)
  let bottom e = (ED.top e.ve.evars, ND.bottom e.ve.nvars)
  let is_bot (e, n) = ED.is_bot e || ND.is_bot n

  let print_v fmt (enum, num) =
    if is_bot (enum, num) then
      Format.fprintf fmt "⊥"
    else
      Format.fprintf fmt "@[(%a,@ %a)@]" ED.print enum ND.print num

  (* join : abs_v -> abs_v -> abs_v
     widen : abs_v -> abs_v -> abs_v
     meet : abs_v -> abs_v -> abs_v
     The enumerated component is always joined; only the numerical
     component is widened. *)
  let join a b =
    if is_bot a then b
    else if is_bot b then a
    else (ED.join (fst a) (fst b), ND.join (snd a) (snd b))

  let widen a b =
    if is_bot a then b
    else if is_bot b then a
    else (ED.join (fst a) (fst b), ND.widen (snd a) (snd b))

  let meet (e1, n1) (e2, n2) =
    if is_bot (e1, n1) then ED.vtop e1, ND.vbottom n1
    else if is_bot (e2, n2) then ED.vtop e2, ND.vbottom n2
    else
      try (ED.meet e1 e2, ND.meet n1 n2)
      with Bot -> ED.vtop e1, ND.vbottom n1

  (* eq_v : abs_v -> abs_v -> bool
     subset_v : abs_v -> abs_v -> bool *)
  let eq_v (a, b) (c, d) =
    (is_bot (a, b) && is_bot (c, d))
    || (ED.eq a c && ND.eq b d)

  let subset_v (a, b) (c, d) =
    (is_bot (a, b))
    || (not (is_bot (c, d)) && ED.subset a c && ND.subset b d)

  (* apply_cl : abs_v -> conslist -> abs_v
     Restricts a value by a constraint list. Disjunctions are evaluated
     branch by branch and the results joined. *)
  let rec apply_cl (enum, num) (ec, nc, r) =
    match r with
    | CLTrue ->
      (try (ED.apply_cl enum ec, ND.apply_cl num nc)
       with Bot -> ED.vtop enum, ND.vbottom num)
    | CLFalse ->
      (ED.vtop enum, ND.vbottom num)
    | CLAnd (a, b) ->
      let enum, num = apply_cl (enum, num) (ec, nc, a) in
      let enum, num = apply_cl (enum, num) ([], nc, b) in
      enum, num
    | CLOr ((eca, nca, ra), (ecb, ncb, rb)) ->
      let a = apply_cl (enum, num) (ec@eca, nc@nca, ra) in
      let b = apply_cl (enum, num) (ec@ecb, nc@ncb, rb) in
      join a b

  (* apply_cl_all_cases : abs_v -> conslist -> abs_v list
     Like apply_cl, but keeps each non-bottom disjunct as a separate value
     instead of joining. Duplicate cases (by eq_v) are dropped. *)
  let rec apply_cl_all_cases v (ec, nc, r) =
    match r with
    | CLTrue ->
      let v =
        try ED.apply_cl (fst v) ec, ND.apply_cl (snd v) nc
        with Bot -> ED.vtop (fst v), ND.vbottom (snd v)
      in
      if is_bot v then [] else [v]
    | CLFalse -> []
    | CLAnd (a, b) ->
      let q1 = apply_cl_all_cases v (ec, nc, a) in
      List.flatten (List.map (fun c -> apply_cl_all_cases c ([], [], b)) q1)
    | CLOr ((eca, nca, ra), (ecb, ncb, rb)) ->
      let la = apply_cl_all_cases v (ec@eca, nc@nca, ra) in
      let lb = apply_cl_all_cases v (ec@ecb, nc@ncb, rb) in
      lb @ List.filter (fun a -> not (List.exists (fun b -> eq_v a b) lb)) la

  (* *************************** INTERPRET *************************** *)

  (* init_env : cmdline_opt -> rooted_prog -> env
     Builds the program formula, eliminates enum variables that are equated
     at the top level, and creates the two initial locations
     (L/must_reset = tt and L/must_reset ≠ tt). *)
  let init_env opt rp =
    let f = Transform.f_of_prog_incl_init rp false in
    let f = simplify_k (get_root_true f) f in
    Format.printf "Complete formula:@.%a@.@." Formula_printer.print_expr f;

    (* Variables whose name starts with 'L' (presumably previous-cycle
       copies — TODO confirm) are never chosen for elimination. *)
    let is_l_var s = String.length s > 0 && s.[0] = 'L' in

    let facts = get_root_true f in
    let f, rp, repls =
      List.fold_left
        (fun (f, (rp : rooted_prog), repls) eq ->
          match eq with
          | BEnumCons (E_EQ, a, EIdent b)
            when not (is_l_var a) && not (is_l_var b) ->
            let a = try List.assoc a repls with Not_found -> a in
            let b = try List.assoc b repls with Not_found -> b in
            if a = b then
              f, rp, repls
            else begin
              let keep, repl =
                if String.length a <= String.length b then a, b else b, a
              in
              Format.printf "Replacing %s with %s@." repl keep;
              let f = formula_replace_evars [repl, keep; "L"^repl, "L"^keep] f in
              let rp =
                { rp with
                  all_vars = List.filter (fun (_, id, _) -> id <> repl) rp.all_vars }
              in
              (* Earlier replacements that targeted [repl] must now target
                 [keep], and likewise for the "L" counterparts; otherwise
                 [repls] would map to a variable that no longer exists. *)
              let redirect k =
                if k = repl then keep
                else if k = "L"^repl then "L"^keep
                else k
              in
              let repls =
                [repl, keep; "L"^repl, "L"^keep]
                @ List.map (fun (r, k) -> r, redirect k) repls
              in
              f, rp, repls
            end
          | _ -> f, rp, repls)
        (f, rp, []) facts
    in
    let f = simplify_k (get_root_true f) f in
    Format.printf "Complete formula after simpl:@.%a@.@."
      Formula_printer.print_expr f;

    let guarantees = Transform.guarantees_of_prog rp in
    let guarantees =
      List.map (fun (id, f, v) -> id, formula_replace_evars repls f, v) guarantees
    in
    Format.printf "Guarantees:@.";
    List.iter
      (fun (id, f, _) ->
        Format.printf " %s: %a@." id Formula_printer.print_expr f)
      guarantees;
    Format.printf "@.";

    let ve = mk_varenv rp f (conslist_of_f f) in

    let env = {
      rp;
      opt;
      ve;
      f;
      guarantees;
      loc = Hashtbl.create 2;
      counter = ref 2;
    } in

    (* add initial disjunction : L/must_reset = tt, L/must_reset ≠ tt *)
    let rstc = BEnumCons (E_EQ, "L/must_reset", EItem bool_true) in
    let rstf = simplify_k [rstc] f in
    let rstf = simplify_k (get_root_true rstf) rstf in

    let nrstc = BEnumCons (E_NE, "L/must_reset", EItem bool_true) in
    let nrstf = simplify_k [nrstc] f in
    let nrstf = simplify_k (get_root_true nrstf) nrstf in

    Hashtbl.add env.loc 0 {
      id = 0;
      depth = 0;
      def = apply_cl (top env) (conslist_of_f rstc);
      is_init = true;
      f = rstf;
      cl = conslist_of_f rstf;
      in_c = 0;
      v = bottom env;
      out_t = [];
      in_t = [];
      verif_g = [];
      violate_g = [];
    };
    Hashtbl.add env.loc 1 {
      id = 1;
      depth = 0;
      def = apply_cl (top env) (conslist_of_f nrstc);
      is_init = false;
      f = nrstf;
      cl = conslist_of_f nrstf;
      in_c = 0;
      v = bottom env;
      out_t = [];
      in_t = [];
      verif_g = [];
      violate_g = [];
    };

    env

  (* ternary_conds : bool_expr -> (bool_expr * bool_expr) list
     Walks the top-level conjunction of a formula and returns, for each
     ternary found, its condition paired with the whole ternary. Ternaries
     nested inside another ternary are not visited. *)
  let rec ternary_conds = function
    | BAnd (a, b) -> ternary_conds a @ ternary_conds b
    | BTernary (c, _, _) as x -> [c, x]
    | _ -> []

  (* forget_typed_vars : abs_v -> (id * typ) list -> abs_v
     Forgets each listed variable in the domain matching its type. *)
  let forget_typed_vars (enum, num) vars =
    let ef, nf =
      List.fold_left
        (fun (ef, nf) (var, t) ->
          match t with
          | TEnum _ -> var::ef, nf
          | TReal | TInt -> ef, var::nf)
        ([], []) vars
    in
    (ED.forgetvars enum ef, List.fold_left ND.forgetvar num nf)

  (* pass_cycle : varenv -> abs_v -> abs_v
     unpass_cycle : varenv -> abs_v -> abs_v
     pass_cycle assigns a := b for every (a, b, _) in ve.cycle, then forgets
     ve.forget. unpass_cycle assigns b := a for the same pairs, then forgets
     ve.forget_inv. (Presumably these move a state to the next cycle and
     back — TODO confirm against Varenv.) *)
  let pass_cycle (ve : varenv) (enum, num) =
    let assign_e, assign_n =
      List.fold_left
        (fun (ae, an) (a, b, t) ->
          match t with
          | TEnum _ -> (a, b)::ae, an
          | TInt | TReal -> ae, (a, NIdent b)::an)
        ([], []) ve.cycle
    in
    let enum = ED.assign enum assign_e in
    let num = ND.assign num assign_n in
    forget_typed_vars (enum, num) ve.forget

  let unpass_cycle (ve : varenv) (enum, num) =
    let assign_e, assign_n =
      List.fold_left
        (fun (ae, an) (a, b, t) ->
          match t with
          | TEnum _ -> (b, a)::ae, an
          | TInt | TReal -> ae, (b, NIdent a)::an)
        ([], []) ve.cycle
    in
    let enum = ED.assign enum assign_e in
    let num = ND.assign num assign_n in
    forget_typed_vars (enum, num) ve.forget_inv

  (* print_locs : env -> unit *)
  let print_locs_defs e =
    Hashtbl.iter
      (fun id loc -> Format.printf "q%d: @[%a@]@." id print_v loc.def)
      e.loc

  let print_locs e =
    Hashtbl.iter
      (fun id loc ->
        Format.printf "@.";
        Format.printf "q%d (depth = %d):@. D: @[%a@]@." id loc.depth print_v loc.def;
        (*Format.printf " F: (%a)@." Formula_printer.print_expr loc.f;*)
        Format.printf " V: %a@." print_v loc.v;
        Format.printf " -> @[[%a]@]@."
          (print_list (fun fmt i -> Format.fprintf fmt "q%d" i) ", ") loc.out_t)
      e.loc

  (* dump_graphwiz_trans_graph : env -> string -> unit
     Writes the location transition graph in Graphviz dot format. Edges
     into a location violating more guarantees are drawn in dark red.
     The channel is closed even if writing fails. *)
  let dump_graphwiz_trans_graph e file =
    let o = open_out file in
    try
      let fmt = Format.formatter_of_out_channel o in
      Format.fprintf fmt "digraph G{@.";
      Hashtbl.iter
        (fun id loc ->
          if loc.is_init then
            Format.fprintf fmt " q%d [shape=doublecircle, label=\"q%d [%a]\"];@."
              id id (print_list Format.pp_print_string ", ") loc.violate_g
          else
            Format.fprintf fmt " q%d [label=\"q%d [%a]\"];@."
              id id (print_list Format.pp_print_string ", ") loc.violate_g;
          let n1 = List.length loc.violate_g in
          List.iter
            (fun v ->
              let n2 = List.length (Hashtbl.find e.loc v).violate_g in
              let c, w = if n2 > n1 then "#770000", 1 else "black", 2 in
              Format.fprintf fmt " q%d -> q%d [color = \"%s\", weight = %d];@."
                id v c w)
            loc.out_t)
        e.loc;
      Format.fprintf fmt "}@.";
      close_out o
    with ex ->
      close_out_noerr o;
      raise ex

  (* chaotic_iter : env -> unit
     Fills the values of loc[*].v, and updates out_t and in_t. Each
     location taken from the worklist first gets a local fixpoint over its
     own cycle (join for opt.widen_delay steps, then widening). Its image
     is then propagated to every location whose definition it meets. *)
  let chaotic_iter e =
    let delta = ref [] in

    (* Fill up initial states *)
    Hashtbl.iter
      (fun q loc ->
        loc.out_t <- [];
        loc.in_t <- [];
        loc.in_c <- 0;
        if loc.is_init then begin
          loc.v <- apply_cl (top e) loc.cl;
          delta := q::!delta
        end else
          loc.v <- bottom e)
      e.loc;

    (*print_locs_defs e;*)

    (* Iterate *)
    let it_counter = ref 0 in
    while !delta <> [] do
      let s = List.hd !delta in
      let loc = Hashtbl.find e.loc s in

      incr it_counter;
      Format.printf "@.Iteration %d: q%d@." !it_counter s;

      let start = loc.v in
      let f i =
        let i' = meet i (unpass_cycle e.ve loc.def) in
        join start (apply_cl (meet (pass_cycle e.ve i') loc.def) loc.cl)
      in
      let rec iter n i =
        let fi = f i in
        let j =
          if n < e.opt.widen_delay then join i fi else widen i fi
        in
        if eq_v i j then i else iter (n+1) j
      in
      let y = iter 0 start in
      let z = f y in
      let u = pass_cycle e.ve z in

      if e.opt.verbose_ci then
        Format.printf "Fixpoint: %a@. mem fp: %a@." print_v z print_v u;

      loc.v <- z;

      Hashtbl.iter
        (fun t loc2 ->
          let v = meet u loc2.def in
          let w = apply_cl v loc2.cl in
          if not (is_bot w) then begin
            if e.opt.verbose_ci then
              Format.printf "%d -> %d with:@. %a@." s t print_v w;
            if not (List.mem s loc2.in_t) then
              loc2.in_t <- s::loc2.in_t;
            if not (List.mem t loc.out_t) then
              loc.out_t <- t::loc.out_t;
            if not (subset_v w loc2.v) then begin
              if loc2.in_c < e.opt.widen_delay then
                loc2.v <- join loc2.v w
              else
                loc2.v <- widen loc2.v w;
              loc2.in_c <- loc2.in_c + 1;
              if not (List.mem t !delta) then
                delta := t::!delta
            end
          end)
        e.loc;

      delta := List.filter ((<>) s) !delta
    done;

    (* remove useless locations *)
    let useless = ref [] in
    Hashtbl.iter
      (fun i loc ->
        if is_bot loc.v then begin
          Format.printf "Useless location detected: q%d@." i;
          useless := i::!useless
        end)
      e.loc;
    List.iter (Hashtbl.remove e.loc) !useless;

    (* check which states verify/violate guarantees *)
    Hashtbl.iter
      (fun _ loc ->
        let verif, violate =
          List.partition
            (fun (_, f, _) -> is_bot (apply_cl loc.v (conslist_of_f f)))
            e.guarantees
        in
        loc.verif_g <- List.map (fun (a, _, _) -> a) verif;
        loc.violate_g <- List.map (fun (a, _, _) -> a) violate)
      e.loc;

    print_locs e

  (* fold_decided_ternaries : env -> unit
     For every location, repeatedly picks a ternary whose condition is
     already decided by the location's invariant (one polarity is bottom).
     The decided literal is pushed into the location's definition and
     formula. *)
  let fold_decided_ternaries (e : env) =
    Hashtbl.iter
      (fun _ (loc : location) ->
        let decided (c, _) =
          is_bot (apply_cl loc.v (conslist_of_f c))
          || is_bot (apply_cl loc.v (conslist_of_f (BNot c)))
        in
        let rec loop () =
          try
            let cond, _ = List.find decided (ternary_conds loc.f) in
            let tr =
              if is_bot (apply_cl loc.v (conslist_of_f cond)) then BNot cond
              else cond
            in
            loc.def <- apply_cl loc.def (conslist_of_f tr);
            loc.f <- simplify_k [tr] loc.f;
            loc.f <- simplify_k (get_root_true loc.f) loc.f;
            loc.cl <- conslist_of_f loc.f;
            loop ()
          with Not_found -> ()
        in
        loop ())
      e.loc

  (* find_best_split : env ->
       (int * int * bool_expr * abs_v list * abs_v list) option
     Scores every (location, ternary condition) pair whose condition splits
     the location's invariant into at least two non-bottom cases. Returns
     the highest positive score; on ties the last pair examined wins. The
     result is (score, location id, condition, cases where the condition
     holds, cases where it fails). *)
  let find_best_split (e : env) =
    let best = ref None in
    let voi = List.map (fun (_, _, v) -> v) e.guarantees in
    (* ids >= case_base denote the hypothetical case nodes of a split *)
    let case_base = 1000000 in
    Hashtbl.iter
      (fun q (loc : location) ->
        if loc.depth < e.opt.max_dp_depth then
          List.iter
            (fun (c, exprs) ->
              let cases_t = apply_cl_all_cases (top e) (conslist_of_f c) in
              let cases_f = apply_cl_all_cases (top e) (conslist_of_f (BNot c)) in
              let cases = List.mapi (fun i case -> i, case) (cases_t @ cases_f) in
              let live_cases =
                List.filter (fun (_, case) -> not (is_bot (meet loc.v case))) cases
              in
              if List.length live_cases >= 2 then begin
                (* calculate which transitions qi -> q stay or are destroyed
                   (approximation): (qi, case, survives) *)
                let in_tc =
                  List.flatten @@ List.map
                    (fun qi ->
                      let loci = Hashtbl.find e.loc qi in
                      let v = apply_cl (meet (pass_cycle e.ve loci.v) loc.def) loc.cl in
                      List.map (fun (ci, case) -> qi, ci, not (is_bot (meet v case))) cases)
                    loc.in_t
                in
                (* calculate which transitions q -> qo stay or are destroyed
                   (approximation): (qo, case, survives) *)
                let out_tc =
                  List.flatten @@ List.map
                    (fun (ci, case) ->
                      let v = meet loc.v case in
                      List.map
                        (fun qo ->
                          let loco = Hashtbl.find e.loc qo in
                          let w = apply_cl (meet (pass_cycle e.ve v) loco.def) loco.cl in
                          qo, ci, not (is_bot w))
                        loc.out_t)
                    cases
                in
                (* calculate which cases have a good number of disappearing
                   transitions *)
                let fa =
                  let lost ci tc =
                    List.length (List.filter (fun (_, c, alive) -> not alive && c = ci) tc)
                  in
                  let case_scores =
                    List.map
                      (fun (ci, _) ->
                        let a = lost ci in_tc in
                        let b = lost ci out_tc in
                        a * b + (if a + b > 0 then 1 else 0))
                      cases
                  in
                  5 * List.fold_left max 0 case_scores
                in
                if fa <> 0 then begin
                  (* calculate which states become inaccessible once [loc] is
                     replaced by one node per case *)
                  let fb =
                    if List.for_all (fun (_, _, alive) -> alive) out_tc then 0
                    else begin
                      (* An edge from [src] into [loc] becomes edges into the
                         case nodes whose incoming transition from [src]
                         survives; [loc] itself is no longer a node. *)
                      let redirect src succs =
                        if List.mem loc.id succs then
                          List.map (fun (_, c, _) -> c + case_base)
                            (List.filter (fun (qi, _, alive) -> qi = src && alive) in_tc)
                          @ List.filter ((<>) loc.id) succs
                        else succs
                      in
                      (* transition function for new graph *)
                      let ff id =
                        if id >= case_base then begin
                          let case = id - case_base in
                          redirect loc.id
                            (List.map (fun (qo, _, _) -> qo)
                               (List.filter (fun (_, c, alive) -> c = case && alive) out_tc))
                        end else
                          redirect id (Hashtbl.find e.loc id).out_t
                      in
                      let memo = Hashtbl.create 12 in
                      let rec do_x id =
                        if not (Hashtbl.mem memo id) then begin
                          Hashtbl.add memo id ();
                          List.iter do_x (ff id)
                        end
                      in
                      Hashtbl.iter
                        (fun i loc2 -> if loc2.is_init && i <> loc.id then do_x i)
                        e.loc;
                      if loc.is_init then
                        List.iter (fun (ci, _) -> do_x (ci + case_base)) cases;
                      let disappear_count =
                        (Hashtbl.length e.loc + List.length cases) - Hashtbl.length memo
                      in
                      21 * disappear_count
                    end
                  in
                  (* in/out count, weighted by changing guarantees
                     (currently disabled: weight 0) *)
                  let fc = 0 * (2 * List.length loc.out_t + List.length loc.in_t) in
                  (* calculate number of VOI (variables of interest) that are
                     affected *)
                  let fd =
                    let vlist = refd_evars_of_f exprs in
                    3 * List.length (List.filter (fun v -> List.mem v vlist) voi)
                  in
                  (* give score to split *)
                  let score = fa + fb + fc + fd in
                  Format.printf " %5d + %5d + %5d + %5d = %5d (q%d)@."
                    fa fb fc fd score loc.id;
                  let at_least_best =
                    match !best with
                    | None -> true
                    | Some (s, _, _, _, _) -> score >= s
                  in
                  if score > 0 && at_least_best then
                    best := Some (score, q, c, cases_t, cases_f)
                end
              end)
            (ternary_conds loc.f))
      e.loc;
    !best

  (* split_location : env -> int -> bool_expr -> abs_v list -> abs_v list
                      -> unit
     Replaces location [q] by one fresh location per case that meets its
     invariant. The formula of each new location is specialized with [c]
     (for cases_t) or [BNot c] (for cases_f). *)
  let split_location (e : env) q c cases_t cases_f =
    Format.printf "@.Refine q%d : @[[ %a ]@]@."
      q (print_list print_v ", ") (cases_t@cases_f);
    let (loc : location) = Hashtbl.find e.loc q in
    Hashtbl.remove e.loc loc.id;
    let handle_case cc case =
      if not (is_bot (meet loc.v case)) then begin
        let ff = simplify_k [cc] loc.f in
        let ff = simplify_k (get_root_true ff) ff in
        let loc2 = {
          loc with
          id = (incr e.counter; !(e.counter));
          depth = loc.depth + 1;
          def = meet loc.def case;
          f = ff;
          cl = conslist_of_f ff;
        } in
        Hashtbl.add e.loc loc2.id loc2
      end
    in
    List.iter (handle_case c) cases_t;
    List.iter (handle_case (BNot c)) cases_f

  (* check_guarantees : env -> unit
     Prints, for each guarantee, OK or the locations that violate it. *)
  let check_guarantees (e : env) =
    let check_guarantee (id, f, _) =
      Format.printf "@[";
      let cl = Formula.conslist_of_f f in
      Format.printf "%s:@ %a ⇒ ⊥ @ " id Formula_printer.print_conslist cl;
      let violate = ref [] in
      Hashtbl.iter
        (fun lid loc -> if List.mem id loc.violate_g then violate := lid::!violate)
        e.loc;
      if !violate = [] then
        Format.printf "OK"
      else
        Format.printf "VIOLATED in @[[ %a ]@]"
          (print_list (fun fmt i -> Format.fprintf fmt "q%d" i) ", ") !violate;
      Format.printf "@]@ "
    in
    if e.guarantees <> [] then begin
      Format.printf "Guarantee @[";
      List.iter check_guarantee e.guarantees;
      Format.printf "@]@."
    end

  (* print_probes : env -> unit
     Prints the interval of every numerical probe variable over the join of
     all location invariants. *)
  let print_probes (e : env) =
    if List.exists (fun (p, _, _) -> p) e.ve.all_vars then begin
      let final = Hashtbl.fold (fun _ loc v -> join v loc.v) e.loc (bottom e) in
      Format.printf "Probes: @[";
      List.iter
        (fun (p, id, ty) ->
          if p then
            (match ty with
             | TInt | TReal ->
               Format.printf "%a ∊ %a@ " Formula_printer.print_id id
                 ND.print_itv (ND.project (snd final) id)
             | TEnum _ ->
               Format.printf "%a : enum variable@ " Formula_printer.print_id id))
        e.ve.all_vars;
      Format.printf "@]@."
    end

  (* do_prog : cmdline_opt -> rooted_prog -> unit
     Alternates chaotic iteration and refinement. It stops when no split
     scores positively, or when the partition reaches opt.max_dp_width
     locations. It then reports guarantees and probes. A dot dump of each
     refinement step is written to /tmp/partNNN.dot. *)
  let do_prog opt rp =
    let e = init_env opt rp in

    let rec refine n =
      Format.printf "@.--------------@.Refinement #%d@." n;
      chaotic_iter e;
      dump_graphwiz_trans_graph e (Format.sprintf "/tmp/part%03d.dot" n);
      if Hashtbl.length e.loc < e.opt.max_dp_width then begin
        (* put true or false conditions into location definition *)
        fold_decided_ternaries e;
        (* find splitting condition *)
        match find_best_split e with
        | None ->
          Format.printf "@.Found no more possible refinement.@.@."
        | Some (_, q, c, cases_t, cases_f) ->
          split_location e q c cases_t cases_f;
          refine (n+1)
      end
    in
    refine 0;

    (* Check guarantees *)
    check_guarantees e;

    (* Examine probes *)
    print_probes e

end