open Ast
open Util
open Typing
open Formula
(* An element of an enumeration type, as carried by [TEnum]
   (presumably a constructor name). *)
type item = string
(* Enumerated variable: its identifier and the items of its [TEnum] type. *)
type evar = id * item list
(* Numerical variable: its identifier and a flag that is [true] for a
   [TReal] variable and [false] for a [TInt] variable. *)
type nvar = id * bool
(* Variable environment built by [mk_varenv]. *)
type varenv = {
(* Enumerated variables of [all_vars]. *)
evars : evar list;
(* Numerical (integer or real) variables of [all_vars]. *)
nvars : nvar list;
(* Position of each enumerated variable in the chosen variable order. *)
ev_order : (id, int) Hashtbl.t;
(* One "L<x>" variable per LAST variable, typed like x. The leading boolean
   is always false here; its meaning comes from the program's all_vars
   (NOTE(review): confirm its semantics in the AST). *)
last_vars : (bool * id * typ) list;
(* [last_vars] followed by the program's own variables. *)
all_vars : (bool * id * typ) list;
(* For each LAST variable: ("L<y>", y, type), i.e. the next value of the
   LAST variable is the current value of y. *)
cycle : (id * id * typ) list; (* s'(x) = s(y) *)
(* Program variables whose next value is left unconstrained. *)
forget : (id * typ) list; (* s'(x) not specified *)
}
(*
  extract_linked_evars_root : (_ * id * _) list * _ * _ -> (id * id) list

  Extract every pair of enumerated variables (x, y) linked by an enum
  constraint (_, x, EIdent y) found in the first component of the given
  triple (presumably equations like x = y or x != y). The first field of
  each constraint (presumably the comparison operator) is ignored, so
  equalities and disequalities both contribute. The second and third
  components of the triple are ignored.

  Pairs come out in reverse order of appearance, and a pair may appear
  several times in the result.
*)
let extract_linked_evars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
       match v with
       | EIdent y -> (x, y) :: acc
       | _ -> acc)
    [] ecl
(*
  extract_const_vars_root : (_ * id * _) list * _ * _ -> id list

  Return every variable x compared with an enumeration constant, i.e. the
  x of every constraint (_, x, EItem _) in the first component of the
  given triple. The result is in reverse order of appearance and may
  contain duplicates. The second and third components are ignored.
*)
let extract_const_vars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, v) ->
       match v with
       | EItem _ -> x :: acc
       | _ -> acc)
    [] ecl
(*
  scope_constrict : id list -> (id * id) list -> id list

  Reorder the variables of [vars] so as to (approximately) minimise the
  sum, over every couple (x, y) of [cp_id], of the distance between the
  positions of x and y in the resulting list. This is a heuristic meant to
  keep the EDD from exploding in size when expressing equations such as
  x = y && u = v && a != b.

  Local search: starting from the given order, each variable in turn is
  moved to the front and then slid through every position. Any strictly
  cheaper order found is kept. The search stops after a full pass with no
  improvement.

  Progress is printed on standard output ("SCA", then one dot per pass).
  Raises Not_found if a couple mentions a variable absent from [vars].
*)
let scope_constrict vars cp_id =
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  (* Couples expressed as variable indices. *)
  let cp_i = List.map
      (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
      cp_id in
  (* Cost of the permutation [perm], where perm.(pos) is the index of the
     variable placed at position pos. *)
  let eval perm =
    let r = Array.make n (-1) in
    Array.iteri (fun pos var -> r.(var) <- pos) perm;
    (* sanity check: [perm] must be a permutation of 0..n-1 *)
    Array.iter (fun x -> assert (x <> (-1))) r;
    List.fold_left
      (fun s (x, y) -> s + abs (r.(x) - r.(y)))
      0 cp_i
  in
  let best = Array.init n (fun i -> i) in
  (* Cost of [best], cached so it is not recomputed for every candidate. *)
  let best_cost = ref (eval best) in
  let improved = ref true in
  Format.printf "SCA";
  while !improved do
    Format.printf ".@?";
    improved := false;
    let try_s candidate =
      let c = eval candidate in
      if c < !best_cost then begin
        Array.blit candidate 0 best 0 n;
        best_cost := c;
        improved := true
      end
    in
    for i = 0 to n-1 do
      let tt = Array.copy best in
      (* move item i at beginning *)
      let temp = tt.(i) in
      for j = i downto 1 do tt.(j) <- tt.(j-1) done;
      tt.(0) <- temp;
      (* try all positions by sliding it one step to the right at a time *)
      try_s tt;
      for j = 1 to n-1 do
        let temp = tt.(j-1) in
        tt.(j-1) <- tt.(j);
        tt.(j) <- temp;
        try_s tt
      done
    done
  done;
  Format.printf "@.";
  Array.to_list (Array.map (Array.get var_i) best)
(*
  force_ordering : id list -> (float * id list) list -> id list

  Determine a good ordering for the enumerated variables [vars] using the
  FORCE heuristic. Each group (w, l) of [groups] is a weighted hyperedge.
  At every iteration, each group's centre is the mean position of its
  members, and each variable receives the weighted average of the centres
  of the groups it belongs to. Variables are then sorted by that key.
  Variables belonging to no group get key [float_of_int i], where i is
  their index in [vars].

  The pseudo-variable "#BEGIN" may appear in a group. It counts as
  position -w (w being that group's weight), which pulls the group's other
  members towards the front of the order.

  The update is iterated a fixed 501 times.
  Raises Not_found if a group mentions a name that is neither in [vars]
  nor "#BEGIN".
*)
let force_ordering vars groups =
  let var_i = Array.of_list vars in
  let n = Array.length var_i in
  let i_var = Hashtbl.create n in
  Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
  (* virtual anchor standing before the first variable *)
  Hashtbl.add i_var "#BEGIN" (-1);
  let ngroups = List.map
      (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
      groups in
  (* ord.(pos) is the index of the variable currently at position pos *)
  let ord = Array.init n (fun i -> i) in
  for _iter = 0 to 500 do
    (* rev.(v) is the current position of variable v *)
    let rev = Array.make n (-1) in
    for i = 0 to n-1 do rev.(ord.(i)) <- i done;
    let bw = Array.make n 0. in
    let w = Array.make n 0. in
    let gfun (gw, l) =
      (* sum of member positions; "#BEGIN" (index -1) counts as -gw *)
      let sp = List.fold_left
          (fun acc i ->
             acc +. (if i = -1 then -.gw else float_of_int rev.(i)))
          0. l
      in
      let b = sp /. float_of_int (List.length l) in
      List.iter (fun i -> if i >= 0 then begin
          bw.(i) <- bw.(i) +. (gw *. b);
          w.(i) <- w.(i) +. gw end)
        l
    in
    List.iter gfun ngroups;
    (* NOTE(review): ungrouped variables are keyed by their original index
       [i], not by their current position [rev.(i)]; confirm this anchoring
       is intended. *)
    let b = Array.init n
        (fun i ->
           if w.(i) = 0. then
             float_of_int i
           else bw.(i) /. w.(i)) in
    (* Float.compare is the float-specialised, total order formerly
       obtained through the removed Pervasives.compare. *)
    let ol = List.sort
        (fun i j -> Float.compare b.(i) b.(j))
        (Array.to_list ord) in
    List.iteri (fun pos v -> ord.(pos) <- v) ol
  done;
  List.map (Array.get var_i) (Array.to_list ord)
(*
  mk_varenv : rooted_prog -> _ -> _ -> varenv

  Build the variable environment of the rooted program [rp]:
  - last_vars: one variable "L<x>" for every LAST variable returned by
    Transform.extract_last_vars on [f], typed like variable x of
    rp.all_vars (its leading boolean is always false);
  - all_vars: last_vars followed by rp.all_vars;
  - evars / nvars: the enumerated / numerical variables of all_vars (the
    nvar flag is true for TReal, false for TInt);
  - ev_order: a position for every enumerated variable, computed with the
    FORCE heuristic from the enum constraints of [cl] and from naming
    conventions ("init"/"state" suffixes, x / "L"^x pairs);
  - cycle: a triple ("L<x>", x, ty) for every LAST variable;
  - forget: every variable of rp.all_vars with its type.
  Prints the variable list and the chosen order on standard output.
  Raises Not_found if a LAST variable "L<x>" has no x in rp.all_vars.
*)
let mk_varenv (rp : rooted_prog) f cl =
  (* add variables from LASTs: "L<u>" gets the type of u *)
  let last_vars = uniq_sorted
      (List.sort compare (Transform.extract_last_vars f)) in
  let last_vars = List.map
      (fun id ->
         let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
         in false, id, ty)
      last_vars in
  let all_vars = last_vars @ rp.all_vars in
  Format.printf "Vars: @[<hov>%a@]@.@."
    (print_list Ast_printer.print_typed_var ", ")
    all_vars;
  (* split variables by type; the nvar flag is true for reals *)
  let num_vars, enum_vars = List.fold_left
      (fun (nv, ev) (_, id, t) -> match t with
         | TEnum ch -> nv, (id, ch)::ev
         | TInt -> (id, false)::nv, ev
         | TReal -> (id, true)::nv, ev)
      ([], []) all_vars in
  (* calculate order for enumerated variables *)
  let evars = List.map fst enum_vars in
  (* couples of enum variables compared to each other, normalised by
     ord_couple and deduplicated; plain [compare] as already used above
     (Pervasives was removed in OCaml 5) *)
  let lv = extract_linked_evars_root cl in
  let lv = uniq_sorted
      (List.sort compare (List.map ord_couple lv)) in
  (* FORCE groups (weight, members); "#BEGIN" pulls members to the front *)
  let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
  (* variables compared with a constant *)
  let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
                       (extract_const_vars_root cl)) in
  let lv_f = lv_f @ (List.map (fun v -> (5.0, ["#BEGIN"; v]))
                       (List.filter (fun n -> is_suffix n "init") evars)) in
  let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
                       (List.filter (fun n -> is_suffix n "state") evars)) in
  (* keep each variable close to its LAST copy *)
  let lv_f = lv_f @
             (List.map (fun v -> (0.7, [v; "L"^v]))
                (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in
  (* FORCE is used; the SCA branch is kept for experimentation *)
  let evars_ord =
    if true then
      time "FORCE" (fun () -> force_ordering evars lv_f)
    else
      time "SCA" (fun () -> scope_constrict evars lv)
  in
  (* disabled alternative: regroup "init" then "state" variables first *)
  let evars_ord =
    if false then
      let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
      let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
      (List.rev va) @ vb @ vc
    else
      evars_ord
  in
  let ev_order = Hashtbl.create (List.length evars) in
  List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;
  Format.printf "Order for variables: @[<hov>[%a]@]@."
    (print_list Formula_printer.print_id ", ") evars_ord;
  (* calculate cycle variables and forget variables *)
  let cycle = List.fold_left
      (fun q (_, id, ty) ->
         if id.[0] = 'L' then
           (id, String.sub id 1 (String.length id - 1), ty)::q
         else q)
      [] last_vars
  in
  let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
  { evars = enum_vars; nvars = num_vars; ev_order;
    last_vars; all_vars; cycle; forget }