open Ast
open Util
open Typing
open Formula

(* An item of an enumerated type: TEnum carries a list of these
   (see mk_varenv). *)
type item = string

(* An enumerated variable together with the list of items of its type. *)
type evar = id * item list

(* A numerical variable; mk_varenv sets the flag to true for TReal
   variables and to false for TInt variables. *)
type nvar = id * bool

(* Variable environment, as built by mk_varenv. *)
type varenv = {
  evars : evar list;                  (* enumerated variables *)
  nvars : nvar list;                  (* numerical variables *)
  ev_order : (id, int) Hashtbl.t;     (* rank of each enumerated variable
                                         in the chosen ordering *)
  last_vars : (bool * id * typ) list; (* variables introduced for LASTs *)
  all_vars : (bool * id * typ) list;  (* last_vars followed by the
                                         program variables *)
  cycle : (id * id * typ) list;       (* s'(x) = s(y) *)
  forget : (id * typ) list;           (* s'(x) not specified *)
}

(* extract_linked_evars_root : rooted constraint triple -> (id * id) list

   Extract all pairs of enum-type variables (x, y) appearing in an equation
   like x = y or x != y. Only the first component ecl of the triple is
   inspected; presumably it is the enumerated-constraint list (conslist).
   Every element (_, x, v) of ecl whose right-hand side v is EIdent y
   contributes the pair (x, y).
   Pairs come out in the reverse order of ecl, and a couple may appear
   several times in the result. *)
let extract_linked_evars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
      match v with
      | EIdent y -> (x, y) :: acc
      | _ -> acc)
    [] ecl

(* extract_const_vars_root : rooted constraint triple -> id list

   Return every variable x for which ecl contains an element (_, x, v)
   whose right-hand side v is a constant item (EItem _). Variables come out
   in the reverse order of ecl and may be repeated. *)
let extract_const_vars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
      match v with
      | EItem _ -> x :: acc
      | _ -> acc)
    [] ecl

(* scope_constrict : id list -> (id * id) list -> id list

   Orders the variables of the first argument so as to minimize the sum of
   the distances between the positions of the two variables of each couple
   of the second list.
(The minimisation is approximate: this is a heuristic whose purpose is
   to keep the EDD from exploding in size when expressing equations such as
   x = y && u = v && a != b.)

   Starting from the given order, each pass takes every variable in turn,
   moves it to the front, then bubbles it through every position, keeping
   any strictly cheaper arrangement found. Passes repeat until one of them
   brings no improvement. Progress is printed on standard output (SCA, then
   one dot per pass, then a newline).

   Raises Not_found if a couple mentions a variable absent from [vars]. *)
let scope_constrict vars cp_id =
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  let cp_i =
    List.map (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y) cp_id
  in
  (* Cost of [perm], where perm.(pos) is the index of the variable placed
     at position pos: sum over all couples of their distance. *)
  let eval perm =
    let r = Array.make n (-1) in
    Array.iteri (fun pos var -> r.(var) <- pos) perm;
    Array.iter (fun x -> assert (x <> (-1))) r;
    List.fold_left (fun s (x, y) -> s + abs (r.(x) - r.(y))) 0 cp_i
  in
  let best = Array.init n (fun i -> i) in
  (* Cost of [best], cached: [best] only changes inside [try_s]. *)
  let best_cost = ref (eval best) in
  let improved = ref true in
  Format.printf "SCA";
  while !improved do
    Format.printf ".@?";
    improved := false;
    let try_s perm =
      let cost = eval perm in
      if cost < !best_cost then begin
        Array.blit perm 0 best 0 n;
        best_cost := cost;
        improved := true
      end
    in
    for i = 0 to n - 1 do
      let tt = Array.copy best in
      (* move item i at beginning *)
      let temp = tt.(i) in
      for j = i downto 1 do
        tt.(j) <- tt.(j - 1)
      done;
      tt.(0) <- temp;
      (* try all positions by bubbling it towards the end *)
      try_s tt;
      for j = 1 to n - 1 do
        let temp = tt.(j - 1) in
        tt.(j - 1) <- tt.(j);
        tt.(j) <- temp;
        try_s tt
      done
    done
  done;
  Format.printf "@.";
  Array.to_list (Array.map (Array.get var_i) best)

(* force_ordering : id list -> (float * id list) list -> id list

   Determine a good ordering for enumerated variables based on the FORCE
   algorithm. Each group (w, l) is a weighted hyperedge over the variables
   of l. The pseudo-variable #BEGIN stands for the start of the ordering and
   is placed at position -. w, so heavier groups pull their members harder
   towards the front.

   At each iteration, the centre of gravity of every group is computed from
   the current positions. Every variable's key is the weight-averaged centre
   of the groups it belongs to; a variable in no group gets its original
   index as key. The variables are then sorted by key, ties keeping their
   current relative order.

   The process runs for at most 501 iterations. It stops early once an
   iteration leaves the ordering unchanged: an iteration depends only on
   the current ordering, so every later one would be a no-op.

   Raises Not_found if a group mentions an unknown variable. *)
let force_ordering vars groups =
  let max_iterations = 501 in
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  (* index -1 denotes the #BEGIN anchor *)
  Hashtbl.add i_var "#BEGIN" (-1);
  let ngroups =
    List.map (fun (w, l) -> w, List.map (Hashtbl.find i_var) l) groups
  in
  let ord = Array.init n (fun i -> i) in
  (* One FORCE step: the ordering obtained from the current [ord]. *)
  let step () =
    let rev = Array.make n (-1) in
    Array.iteri (fun pos v -> rev.(v) <- pos) ord;
    let bw = Array.make n 0. in
    let w = Array.make n 0. in
    let add_group (gw, l) =
      let sp =
        List.fold_left (+.) 0.
          (List.map
             (fun i -> if i = -1 then -.gw else float_of_int rev.(i))
             l)
      in
      let b = sp /. float_of_int (List.length l) in
      List.iter
        (fun i ->
          if i >= 0 then begin
            bw.(i) <- bw.(i) +. (gw *. b);
            w.(i) <- w.(i) +. gw
          end)
        l
    in
    List.iter add_group ngroups;
    let key =
      Array.init n (fun i ->
          if w.(i) = 0. then float_of_int i else bw.(i) /. w.(i))
    in
    Array.of_list
      (List.sort (fun i j -> compare key.(i) key.(j)) (Array.to_list ord))
  in
  let rec iterate k =
    if k < max_iterations then begin
      let next = step () in
      if next <> ord then begin
        Array.blit next 0 ord 0 n;
        iterate (k + 1)
      end
    end
  in
  iterate 0;
  List.map (Array.get var_i) (Array.to_list ord)

(* Make varenv: takes a program, and extracts
   - the list of enumerated variables
   - the list of numerical variables
   - a good order for enumerated variables
   - cycle, forget

   [f] is handed to Transform.extract_last_vars to find the variables
   introduced for LASTs. [cl] is a rooted constraint triple; the
   enumerated variables it links together or compares with constants are
   used to guide the ordering. The variable list and the chosen order are
   printed on standard output.
   Raises Not_found if a LAST variable L^x has no program variable x. *)
let mk_varenv (rp : rooted_prog) f cl =
  (* add variables from LASTs: each id is L followed by the name of a
     program variable, and takes that variable's type *)
  let last_vars =
    uniq_sorted (List.sort compare (Transform.extract_last_vars f))
  in
  let last_vars =
    List.map
      (fun id ->
        let (_, _, ty) =
          List.find (fun (_, u, _) -> id = "L" ^ u) rp.all_vars
        in
        false, id, ty)
      last_vars
  in
  let all_vars = last_vars @ rp.all_vars in
  Format.printf "Vars: @[%a@]@.@."
    (print_list Ast_printer.print_typed_var ", ") all_vars;
  let num_vars, enum_vars =
    List.fold_left
      (fun (nv, ev) (_, id, t) ->
        match t with
        | TEnum ch -> nv, (id, ch) :: ev
        | TInt -> (id, false) :: nv, ev
        | TReal -> (id, true) :: nv, ev)
      ([], []) all_vars
  in
  (* calculate order for enumerated variables *)
  let evars = List.map fst enum_vars in
  let lv = extract_linked_evars_root cl in
  let lv = uniq_sorted (List.sort compare (List.map ord_couple lv)) in
  (* Weighted groups fed to FORCE: linked pairs (1.0); variables compared
     with a constant (10.0), variables whose name ends in init (5.0) or in
     state (3.0), all anchored to #BEGIN; and each variable with its LAST
     copy L^v when that copy is itself an enumerated variable (0.7). *)
  let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
  let lv_f =
    lv_f
    @ List.map (fun v -> (10.0, ["#BEGIN"; v])) (extract_const_vars_root cl)
  in
  let lv_f =
    lv_f
    @ List.map (fun v -> (5.0, ["#BEGIN"; v]))
        (List.filter (fun n -> is_suffix n "init") evars)
  in
  let lv_f =
    lv_f
    @ List.map (fun v -> (3.0, ["#BEGIN"; v]))
        (List.filter (fun n -> is_suffix n "state") evars)
  in
  let lv_f =
    lv_f
    @ List.map (fun v -> (0.7, [v; "L" ^ v]))
        (List.filter (fun n -> List.mem ("L" ^ n) evars) evars)
  in
  (* Strategy switches: FORCE is the active ordering heuristic, and
     scope_constrict is kept as an alternative. The optional regrouping of
     init / state variables is currently disabled. *)
  let use_force = true in
  let group_init_state = false in
  let evars_ord =
    if use_force then time "FORCE" (fun () -> force_ordering evars lv_f)
    else time "SCA" (fun () -> scope_constrict evars lv)
  in
  let evars_ord =
    if group_init_state then begin
      let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
      let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
      List.rev va @ vb @ vc
    end else evars_ord
  in
  let ev_order = Hashtbl.create (List.length evars) in
  List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;
  Format.printf "Order for variables: @[[%a]@]@."
    (print_list Formula_printer.print_id ", ") evars_ord;
  (* calculate cycle variables and forget variables:
     each LAST variable L^x is paired with x *)
  let cycle =
    List.fold_left
      (fun q (_, id, ty) ->
        if id.[0] = 'L' then
          (id, String.sub id 1 (String.length id - 1), ty) :: q
        else q)
      [] last_vars
  in
  let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
  { evars = enum_vars; nvars = num_vars; ev_order;
    last_vars; all_vars; cycle; forget }