summaryrefslogtreecommitdiff
path: root/abstract/varenv.ml
diff options
context:
space:
mode:
Diffstat (limited to 'abstract/varenv.ml')
-rw-r--r--abstract/varenv.ml252
1 files changed, 252 insertions, 0 deletions
diff --git a/abstract/varenv.ml b/abstract/varenv.ml
new file mode 100644
index 0000000..1aa17b4
--- /dev/null
+++ b/abstract/varenv.ml
@@ -0,0 +1,252 @@
+open Ast
+open Util
+open Typing
+open Formula
+
+
+
+type item = string
+
+type evar = id * item list
+type nvar = id * bool
+
+type varenv = {
+ evars : evar list;
+ nvars : nvar list;
+ ev_order : (id, int) Hashtbl.t;
+
+ last_vars : (bool * id * typ) list;
+ all_vars : (bool * id * typ) list;
+
+ cycle : (id * id * typ) list; (* s'(x) = s(y) *)
+ forget : (id * typ) list; (* s'(x) not specified *)
+}
+
+
+
+(*
+ extract_linked_evars : conslist -> (id * id) list
+
+ Extract all pairs of enum-type variable (x, y) appearing in an
+ equation like x = y or x != y
+
+ A couple may appear several times in the result.
+*)
+let rec extract_linked_evars_root (ecl, _, r) =
+ let v_ecl = List.fold_left
+ (fun c (_, x, v) -> match v with
+ | EIdent y -> (x, y)::c
+ | _ -> c)
+ [] ecl
+ in
+ v_ecl
+
+let rec extract_const_vars_root (ecl, _, _) =
+ List.fold_left
+ (fun l (_, x, v) -> match v with
+ | EItem _ -> x::l
+ | _ -> l)
+ [] ecl
+
+
+
+(*
+ scope_constrict : id list -> (id * id) list -> id list
+
+ Orders the variable in the first argument such as to minimize the
+ sum of the distance between the position of two variables appearing in
+ a couple of the second list. (minimisation is approximate, this is
+ an heuristic so that the EDD will not explode in size when expressing
+ equations such as x = y && u = v && a != b)
+*)
+let scope_constrict vars cp_id =
+ let var_i = Array.of_list vars in
+ let n = Array.length var_i in
+
+ let i_var = Hashtbl.create n in
+ Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
+
+ let cp_i = List.map
+ (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
+ cp_id in
+
+ let eval i =
+ let r = Array.make n (-1) in
+ Array.iteri (fun pos var -> r.(var) <- pos) i;
+ Array.iteri (fun _ x -> assert (x <> (-1))) r;
+ List.fold_left
+ (fun s (x, y) -> s + abs (r.(x) - r.(y)))
+ 0 cp_i
+ in
+
+ let best = Array.init n (fun i -> i) in
+
+ let usefull = ref true in
+ Format.printf "SCA";
+ while !usefull do
+ Format.printf ".@?";
+
+ usefull := false;
+ let try_s x =
+ if eval x < eval best then begin
+ Array.blit x 0 best 0 n;
+ usefull := true
+ end
+ in
+
+ for i = 0 to n-1 do
+ let tt = Array.copy best in
+ (* move item i at beginning *)
+ let temp = tt.(i) in
+ for j = i downto 1 do tt.(j) <- tt.(j-1) done;
+ tt.(0) <- temp;
+ (* try all positions *)
+ try_s tt;
+ for j = 1 to n-1 do
+ let temp = tt.(j-1) in
+ tt.(j-1) <- tt.(j);
+ tt.(j) <- temp;
+ try_s tt
+ done
+ done
+ done;
+ Format.printf "@.";
+
+ Array.to_list (Array.map (Array.get var_i) best)
+
+
+(*
+ force_ordering : id list -> (float * id list) list -> id list
+
+ Determine a good ordering for enumerate variables based on the FORCE algorithm
+*)
+let force_ordering vars groups =
+ let var_i = Array.of_list vars in
+ let n = Array.length var_i in
+
+ let i_var = Hashtbl.create n in
+ Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
+ Hashtbl.add i_var "#BEGIN" (-1);
+
+ let ngroups = List.map
+ (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
+ groups in
+
+ let ord = Array.init n (fun i -> i) in
+
+ for iter = 0 to 500 do
+ let rev = Array.make n (-1) in
+ for i = 0 to n-1 do rev.(ord.(i)) <- i done;
+
+ let bw = Array.make n 0. in
+ let w = Array.make n 0. in
+
+ let gfun (gw, l) =
+ let sp = List.fold_left (+.) 0.
+ (List.map
+ (fun i -> if i = -1 then -.gw else float_of_int (rev.(i))) l)
+ in
+ let b = sp /. float_of_int (List.length l) in
+ List.iter (fun i -> if i >= 0 then begin
+ bw.(i) <- bw.(i) +. (gw *. b);
+ w.(i) <- w.(i) +. gw end)
+ l
+ in
+ List.iter gfun ngroups;
+
+ let b = Array.init n
+ (fun i ->
+ if w.(i) = 0. then
+ float_of_int i
+ else bw.(i) /. w.(i)) in
+
+ let ol = List.sort
+ (fun i j -> Pervasives.compare b.(i) b.(j))
+ (Array.to_list ord) in
+ Array.blit (Array.of_list ol) 0 ord 0 n
+ done;
+ List.map (Array.get var_i) (Array.to_list ord)
+
+
+(*
+ Make varenv : takes a program, and extracts
+ - list of enum variables
+ - list of num variables
+ - good order for enum variables
+ - cycle, forget
+*)
+
+let mk_varenv (rp : rooted_prog) f cl =
+ (* add variables from LASTs *)
+ let last_vars = uniq_sorted
+ (List.sort compare (Transform.extract_last_vars f)) in
+ let last_vars = List.map
+ (fun id ->
+ let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
+ in false, id, ty)
+ last_vars in
+ let all_vars = last_vars @ rp.all_vars in
+
+ Format.printf "Vars: @[<hov>%a@]@.@."
+ (print_list Ast_printer.print_typed_var ", ")
+ all_vars;
+
+ let num_vars, enum_vars = List.fold_left
+ (fun (nv, ev) (_, id, t) -> match t with
+ | TEnum ch -> nv, (id, ch)::ev
+ | TInt -> (id, false)::nv, ev
+ | TReal -> (id, true)::nv, ev)
+ ([], []) all_vars in
+
+ (* calculate order for enumerated variables *)
+ let evars = List.map fst enum_vars in
+
+ let lv = extract_linked_evars_root cl in
+ let lv = uniq_sorted
+ (List.sort Pervasives.compare (List.map ord_couple lv)) in
+
+ let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
+ let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
+ (extract_const_vars_root cl)) in
+ let lv_f = lv_f @ (List.map (fun v -> (5.0, ["#BEGIN"; v]))
+ (List.filter (fun n -> is_suffix n "init") evars)) in
+ let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
+ (List.filter (fun n -> is_suffix n "state") evars)) in
+ let lv_f = lv_f @
+ (List.map (fun v -> (0.7, [v; "L"^v]))
+ (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in
+ let evars_ord =
+ if true then
+ time "FORCE" (fun () -> force_ordering evars lv_f)
+ else
+ time "SCA" (fun () -> scope_constrict evars lv)
+ in
+
+ let evars_ord =
+ if false then
+ let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
+ let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
+ (List.rev va) @ vb @ vc
+ else
+ evars_ord
+ in
+
+ let ev_order = Hashtbl.create (List.length evars) in
+ List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;
+
+ Format.printf "Order for variables: @[<hov>[%a]@]@."
+ (print_list Formula_printer.print_id ", ") evars_ord;
+
+ (* calculate cycle variables and forget variables *)
+ let cycle = List.fold_left
+ (fun q (_, id, ty) ->
+ if id.[0] = 'L' then
+ (id, String.sub id 1 (String.length id - 1), ty)::q
+ else q)
+ [] last_vars
+ in
+ let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
+
+ { evars = enum_vars; nvars = num_vars; ev_order;
+ last_vars; all_vars; cycle; forget }
+