open Ast
open Util
open Typing
open Formula
(* An item of an enumerated type (the elements of a [TEnum] type's list;
   presumably the enum constant names — confirm against Typing). *)
type item = string
(* Enumerated variable: its identifier and the items of its [TEnum] type. *)
type evar = id * item list
(* Numerical variable: its identifier and a flag that mk_varenv sets to
   [true] for [TReal] variables and [false] for [TInt] variables. *)
type nvar = id * bool
(* Variable environment, built by mk_varenv. *)
type varenv = {
evars : evar list;                    (* enum variables, sorted by ev_order *)
nvars : nvar list;                    (* numerical variables *)
ev_order : (id, int) Hashtbl.t;       (* position of each enum variable in the chosen order *)
last_vars : (bool * id * typ) list;   (* "L"-prefixed copies of variables read through LAST;
                                         mk_varenv sets their bool flag to false *)
all_vars : (bool * id * typ) list;    (* last_vars followed by the program's variables *)
d_vars : id list;                     (* enum variables used as disjunction variables *)
cycle : (id * id * typ) list; (* s'(x) = s(y) *)
forget : (id * typ) list; (* s'(x) not specified *)
forget_inv : (id * typ) list;         (* all_vars minus the variables y that appear as
                                         the second component of a cycle entry *)
}
(*
  extract_last_vars : formula -> id list
  Extract the variables accessed by a LAST, i.e. the identifiers whose
  name starts with 'L', found in the numeric atoms (BRel) and enum
  constraints (BEnumCons) of a formula, descending through BAnd, BOr,
  BNot and BTernary. Duplicates are kept; other nodes contribute nothing.
  Raises Invalid_argument if it reaches an empty identifier.
*)
let rec extract_last_vars f =
  match f with
  | BRel (_, lhs, rhs, _) -> elv_ne lhs @ elv_ne rhs
  | BEnumCons ec -> elv_ec ec
  | BAnd (l, r) | BOr (l, r) -> extract_last_vars l @ extract_last_vars r
  | BNot inner -> extract_last_vars inner
  | BTernary (cond, if_true, if_false) ->
    extract_last_vars cond
    @ extract_last_vars if_true
    @ extract_last_vars if_false
  | _ -> []
(* LAST variables occurring in a numeric expression. *)
and elv_ne e =
  match e with
  | NIdent name -> if name.[0] = 'L' then [name] else []
  | NBinary (_, l, r, _) -> elv_ne l @ elv_ne r
  | NUnary (_, sub, _) -> elv_ne sub
  | _ -> []
(* LAST variables occurring on either side of an enum constraint. *)
and elv_ec (_, lhs, rhs) =
  let from_lhs = if lhs.[0] = 'L' then [lhs] else [] in
  let from_rhs =
    match rhs with
    | EIdent name when name.[0] = 'L' -> [name]
    | _ -> []
  in
  from_lhs @ from_rhs
(*
  extract_linked_evars_root : conslist -> (id * id) list
  Extract all pairs of enum-type variables (x, y) appearing in an
  equation like x = y or x != y, i.e. every enum constraint (_, x, v)
  in the first component of the argument whose right-hand side v is a
  variable (EIdent y). The comparison operator is not inspected.
  A couple may appear several times in the result; pairs come out in
  reverse order of the constraint list.
*)
let extract_linked_evars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
       match v with
       | EIdent y -> (x, y) :: acc
       | _ -> acc)
    [] ecl
(*
  extract_const_vars_root : conslist -> id list
  Extract the enum-type variables x compared to a constant item, i.e.
  every enum constraint (_, x, v) in the first component of the argument
  whose right-hand side v is an item (EItem _). A variable may appear
  several times; the result is in reverse order of the constraint list.
*)
let extract_const_vars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
       match v with
       | EItem _ -> x :: acc
       | _ -> acc)
    [] ecl
(*
  extract_choice_groups : formula -> (float * id list) list
  Build one weighted group per BTernary node of [f]. A group holds the
  sorted, de-duplicated enum variables found anywhere below that node
  (its condition and both branches). Weights start at 0.6 and are divided
  by 3 when entering a condition and by 2 when entering a branch, so
  nested choices weigh less. BAnd/BOr/BNot keep the current weight.
  For an enum constraint "x op item" x is recorded; for "x op y" only y
  is recorded. NOTE(review): dropping x there may be unintentional — confirm.
*)
let extract_choice_groups f =
  (* [collect weight e] returns (groups found in e, variables found in e) *)
  let rec collect weight e =
    match e with
    | BNot inner -> collect weight inner
    | BRel _ | BConst _ -> ([], [])
    | BEnumCons (_, lhs, EItem _) -> ([], [lhs])
    | BEnumCons (_, _, EIdent rhs) -> ([], [rhs])
    | BAnd (l, r) | BOr (l, r) ->
      let (groups_l, vars_l) = collect weight l in
      let (groups_r, vars_r) = collect weight r in
      (groups_l @ groups_r, vars_l @ vars_r)
    | BTernary (cond, if_true, if_false) ->
      let (groups_c, vars_c) = collect (weight /. 3.) cond in
      let (groups_t, vars_t) = collect (weight /. 2.) if_true in
      let (groups_f, vars_f) = collect (weight /. 2.) if_false in
      let all_vars = List.sort compare (vars_c @ vars_t @ vars_f) in
      let group_vars = uniq_sorted all_vars in
      ((weight, group_vars) :: (groups_c @ groups_t @ groups_f), group_vars)
  in
  let (groups, _) = collect 0.6 f in
  groups
(*
  scope_constrict : id list -> (id * id) list -> id list
  Orders the variables in the first argument so as to minimize the sum of
  the distances between the positions of the two variables of each couple
  in the second list. The minimisation is approximate (repeated passes
  that move one variable to every possible position, keeping any strict
  improvement, until a pass improves nothing). This is a heuristic so that
  the EDD will not explode in size when expressing equations such as
  x = y && u = v && a != b.
  Prints "SCA" and one dot per pass on the standard formatter.
  Raises Not_found if a couple mentions a variable absent from [vars].
*)
let scope_constrict vars cp_id =
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  let cp_i = List.map
      (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
      cp_id in
  (* Cost of ordering [i], where [i.(pos)] is the index of the variable at
     position [pos]: sum over couples of the distance between positions. *)
  let eval i =
    let r = Array.make n (-1) in
    Array.iteri (fun pos var -> r.(var) <- pos) i;
    (* [i] must be a permutation: every variable received a position *)
    Array.iter (fun x -> assert (x <> (-1))) r;
    List.fold_left
      (fun s (x, y) -> s + abs (r.(x) - r.(y)))
      0 cp_i
  in
  let best = Array.init n (fun i -> i) in
  (* Cost of [best], kept in sync with it. [best] only changes on a strict
     improvement, so re-evaluating it for every candidate was wasted work. *)
  let best_score = ref (eval best) in
  let usefull = ref true in
  (* Adopt candidate [x] if it is strictly cheaper than [best]. *)
  let try_s x =
    let s = eval x in
    if s < !best_score then begin
      Array.blit x 0 best 0 n;
      best_score := s;
      usefull := true
    end
  in
  Format.printf "SCA";
  while !usefull do
    Format.printf ".@?";
    usefull := false;
    for i = 0 to n-1 do
      let tt = Array.copy best in
      (* move item i at beginning *)
      let temp = tt.(i) in
      for j = i downto 1 do tt.(j) <- tt.(j-1) done;
      tt.(0) <- temp;
      (* slide it through every position, trying each one *)
      try_s tt;
      for j = 1 to n-1 do
        let temp = tt.(j-1) in
        tt.(j-1) <- tt.(j);
        tt.(j) <- temp;
        try_s tt
      done
    done
  done;
  Format.printf "@.";
  Array.to_list (Array.map (Array.get var_i) best)
(*
  force_ordering : id list -> (float * id list) list -> id list
  Determine a good ordering for enumerated variables based on the FORCE
  algorithm. Each group (gw, l) has a centre of gravity: the mean position
  of its members, where the pseudo-variable "#BEGIN" counts as position
  -gw (pulling the group towards the front). Each variable is then moved
  to the gw-weighted average of the centres of the groups it belongs to.
  A variable in no group is given its index in [vars]. This is repeated at
  most 501 times, stopping as soon as an iteration leaves the order
  unchanged.
  Raises Not_found if a group mentions a name that is neither in [vars]
  nor "#BEGIN".
*)
let force_ordering vars groups =
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  (* anchor placed before the first position *)
  Hashtbl.add i_var "#BEGIN" (-1);
  let ngroups = List.map
      (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
      groups in
  (* ord.(pos) = index of the variable currently at position pos *)
  let ord = Array.init n (fun i -> i) in
  (* One FORCE iteration: re-sorts [ord] in place and returns [true] iff
     the order changed. *)
  let step () =
    let rev = Array.make n (-1) in
    for i = 0 to n-1 do rev.(ord.(i)) <- i done;
    let bw = Array.make n 0. in
    let w = Array.make n 0. in
    let gfun (gw, l) =
      let sp = List.fold_left (+.) 0.
          (List.map
             (fun i -> if i = -1 then -.gw else float_of_int (rev.(i))) l)
      in
      let b = sp /. float_of_int (List.length l) in
      List.iter (fun i -> if i >= 0 then begin
            bw.(i) <- bw.(i) +. (gw *. b);
            w.(i) <- w.(i) +. gw end)
        l
    in
    List.iter gfun ngroups;
    let b = Array.init n
        (fun i ->
           if w.(i) = 0. then
             float_of_int i
           else bw.(i) /. w.(i)) in
    let old = Array.to_list ord in
    (* [compare] rather than the removed Pervasives.compare (OCaml >= 5) *)
    let ol = List.sort (fun i j -> compare b.(i) b.(j)) old in
    Array.blit (Array.of_list ol) 0 ord 0 n;
    ol <> old
  in
  (* [ord] is the only state carried between iterations, so once one
     iteration leaves it unchanged all later ones would too: stopping early
     yields exactly the result of running all 501 iterations. *)
  let rec loop iter =
    if iter <= 500 && step () then loop (iter + 1)
  in
  loop 0;
  List.map (Array.get var_i) (Array.to_list ord)
(*
  mk_varenv : rooted_prog -> (id -> bool) -> formula -> conslist -> varenv
  Make varenv: takes a program and extracts
  - the list of enum variables (sorted in the chosen order)
  - the list of num variables (flag true for reals, false for ints)
  - a good order for enum variables, computed with force_ordering on
    weighted groups of related variables
  - cycle, forget and forget_inv
  - the disjunction variables: the enum variables accepted by [disj_fun]
  [f] is scanned for LAST variables and choice groups; [cl] is scanned for
  linked and constant enum variables.
  Prints the variables, the chosen order and the disjunction variables.
  Raises Not_found if a LAST variable "Lx" has no variable "x" in
  [rp.all_vars].
*)
let mk_varenv (rp : rooted_prog) disj_fun f cl =
  (* add variables from LASTs: "Lx" gets the type of program variable "x" *)
  let last_vars = uniq_sorted
      (List.sort compare (extract_last_vars f)) in
  let last_vars = List.map
      (fun id ->
         let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
         in false, id, ty)
      last_vars in
  let all_vars = last_vars @ rp.all_vars in
  Format.printf "Vars: @[<hov>%a@]@.@."
    (print_list Ast_printer.print_typed_var ", ")
    all_vars;
  let num_vars, enum_vars = List.fold_left
      (fun (nv, ev) (_, id, t) -> match t with
         | TEnum ch -> nv, (id, ch)::ev
         | TInt -> (id, false)::nv, ev
         | TReal -> (id, true)::nv, ev)
      ([], []) all_vars in
  (* calculate order for enumerated variables *)
  let evars = List.map fst enum_vars in
  let lv = extract_linked_evars_root cl in
  (* [compare] rather than the removed Pervasives.compare (OCaml >= 5) *)
  let lv = uniq_sorted
      (List.sort compare (List.map ord_couple lv)) in
  (* Weighted groups for force_ordering; "#BEGIN" pulls a variable towards
     the front of the order. *)
  let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
  let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
                       (extract_const_vars_root cl)) in
  let lv_f = lv_f @ (List.map (fun v -> (7.0, ["#BEGIN"; v]))
                       (List.filter (fun n -> is_suffix n "init") evars)) in
  let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
                       (List.filter (fun n -> is_suffix n "act" || is_suffix n "state") evars)) in
  (* keep each variable close to its LAST copy *)
  let lv_f = lv_f @
             (List.map (fun v -> (0.7, [v; "L"^v]))
                (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in
  let lv_f = lv_f @ (extract_choice_groups f) in
  (* FORCE is used; the scope_constrict alternative is kept as a switch *)
  let evars_ord =
    if true then
      time "FORCE" (fun () -> force_ordering evars lv_f)
    else
      time "SCA" (fun () -> scope_constrict evars lv)
  in
  (* disabled alternative: *init variables first, then *state variables *)
  let evars_ord =
    if false then
      let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
      let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
      (List.rev va) @ vb @ vc
    else
      evars_ord
  in
  let ev_order = Hashtbl.create (List.length evars) in
  List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;
  let enum_vars = List.sort
      (fun (id1, _) (id2, _) ->
         compare (Hashtbl.find ev_order id1) (Hashtbl.find ev_order id2))
      enum_vars
  in
  Format.printf "Order for variables: @[<hov>[%a]@]@."
    (print_list Formula_printer.print_id ", ") evars_ord;
  (* calculate cycle variables and forget variables:
     cycle holds (Lx, x, ty) for each LAST variable Lx *)
  let cycle = List.fold_left
      (fun q (_, id, ty) ->
         if id.[0] = 'L' then
           (id, String.sub id 1 (String.length id - 1), ty)::q
         else q)
      [] last_vars
  in
  let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
  (* every variable except those whose value is copied into a LAST variable *)
  let forget_inv = List.map (fun (_, id, ty) -> (id, ty))
      (List.filter
         (fun (_, id, _) ->
            not (List.exists (fun (_, b, _) -> b = id) cycle))
         all_vars) in
  (* use specified disjunction variables *)
  let d_vars = List.filter disj_fun
      (List.map (fun (id, _) -> id) enum_vars) in
  Format.printf "Disjunction variables: @[<hov>[%a]@]@."
    (print_list Formula_printer.print_id ", ") d_vars;
  { evars = enum_vars; nvars = num_vars; ev_order; d_vars;
    last_vars; all_vars; cycle; forget; forget_inv }