diff options
Diffstat (limited to 'abstract/varenv.ml')
-rw-r--r-- | abstract/varenv.ml | 252 |
1 files changed, 252 insertions, 0 deletions
diff --git a/abstract/varenv.ml b/abstract/varenv.ml new file mode 100644 index 0000000..1aa17b4 --- /dev/null +++ b/abstract/varenv.ml @@ -0,0 +1,252 @@ +open Ast +open Util +open Typing +open Formula + + + +type item = string + +type evar = id * item list +type nvar = id * bool + +type varenv = { + evars : evar list; + nvars : nvar list; + ev_order : (id, int) Hashtbl.t; + + last_vars : (bool * id * typ) list; + all_vars : (bool * id * typ) list; + + cycle : (id * id * typ) list; (* s'(x) = s(y) *) + forget : (id * typ) list; (* s'(x) not specified *) +} + + + +(* + extract_linked_evars : conslist -> (id * id) list + + Extract all pairs of enum-type variable (x, y) appearing in an + equation like x = y or x != y + + A couple may appear several times in the result. +*) +let rec extract_linked_evars_root (ecl, _, r) = + let v_ecl = List.fold_left + (fun c (_, x, v) -> match v with + | EIdent y -> (x, y)::c + | _ -> c) + [] ecl + in + v_ecl + +let rec extract_const_vars_root (ecl, _, _) = + List.fold_left + (fun l (_, x, v) -> match v with + | EItem _ -> x::l + | _ -> l) + [] ecl + + + +(* + scope_constrict : id list -> (id * id) list -> id list + + Orders the variable in the first argument such as to minimize the + sum of the distance between the position of two variables appearing in + a couple of the second list. (minimisation is approximate, this is + an heuristic so that the EDD will not explode in size when expressing + equations such as x = y && u = v && a != b) +*) +let scope_constrict vars cp_id = + let var_i = Array.of_list vars in + let n = Array.length var_i in + + let i_var = Hashtbl.create n in + Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i; + + let cp_i = List.map + (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y) + cp_id in + + let eval i = + let r = Array.make n (-1) in + Array.iteri (fun pos var -> r.(var) <- pos) i; + Array.iteri (fun _ x -> assert (x <> (-1))) r; + List.fold_left + (fun s (x, y) -> s + abs (r.(x) - r.(y))) + 0 cp_i + in + + let best = Array.init n (fun i -> i) in + + let usefull = ref true in + Format.printf "SCA"; + while !usefull do + Format.printf ".@?"; + + usefull := false; + let try_s x = + if eval x < eval best then begin + Array.blit x 0 best 0 n; + usefull := true + end + in + + for i = 0 to n-1 do + let tt = Array.copy best in + (* move item i at beginning *) + let temp = tt.(i) in + for j = i downto 1 do tt.(j) <- tt.(j-1) done; + tt.(0) <- temp; + (* try all positions *) + try_s tt; + for j = 1 to n-1 do + let temp = tt.(j-1) in + tt.(j-1) <- tt.(j); + tt.(j) <- temp; + try_s tt + done + done + done; + Format.printf "@."; + + Array.to_list (Array.map (Array.get var_i) best) + + +(* + force_ordering : id list -> (float * id list) list -> id list + + Determine a good ordering for enumerate variables based on the FORCE algorithm +*) +let force_ordering vars groups = + let var_i = Array.of_list vars in + let n = Array.length var_i in + + let i_var = Hashtbl.create n in + Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i; + Hashtbl.add i_var "#BEGIN" (-1); + + let ngroups = List.map + (fun (w, l) -> w, List.map (Hashtbl.find i_var) l) + groups in + + let ord = Array.init n (fun i -> i) in + + for iter = 0 to 500 do + let rev = Array.make n (-1) in + for i = 0 to n-1 do rev.(ord.(i)) <- i done; + + let bw = Array.make n 0. in + let w = Array.make n 0. in + + let gfun (gw, l) = + let sp = List.fold_left (+.) 0. + (List.map + (fun i -> if i = -1 then -.gw else float_of_int (rev.(i))) l) + in + let b = sp /. float_of_int (List.length l) in + List.iter (fun i -> if i >= 0 then begin + bw.(i) <- bw.(i) +. (gw *. b); + w.(i) <- w.(i) +. gw end) + l + in + List.iter gfun ngroups; + + let b = Array.init n + (fun i -> + if w.(i) = 0. then + float_of_int i + else bw.(i) /. w.(i)) in + + let ol = List.sort + (fun i j -> Pervasives.compare b.(i) b.(j)) + (Array.to_list ord) in + Array.blit (Array.of_list ol) 0 ord 0 n + done; + List.map (Array.get var_i) (Array.to_list ord) + + +(* + Make varenv : takes a program, and extracts + - list of enum variables + - list of num variables + - good order for enum variables + - cycle, forget +*) + +let mk_varenv (rp : rooted_prog) f cl = + (* add variables from LASTs *) + let last_vars = uniq_sorted + (List.sort compare (Transform.extract_last_vars f)) in + let last_vars = List.map + (fun id -> + let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars + in false, id, ty) + last_vars in + let all_vars = last_vars @ rp.all_vars in + + Format.printf "Vars: @[<hov>%a@]@.@." + (print_list Ast_printer.print_typed_var ", ") + all_vars; + + let num_vars, enum_vars = List.fold_left + (fun (nv, ev) (_, id, t) -> match t with + | TEnum ch -> nv, (id, ch)::ev + | TInt -> (id, false)::nv, ev + | TReal -> (id, true)::nv, ev) + ([], []) all_vars in + + (* calculate order for enumerated variables *) + let evars = List.map fst enum_vars in + + let lv = extract_linked_evars_root cl in + let lv = uniq_sorted + (List.sort Pervasives.compare (List.map ord_couple lv)) in + + let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in + let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v])) + (extract_const_vars_root cl)) in + let lv_f = lv_f @ (List.map (fun v -> (5.0, ["#BEGIN"; v])) + (List.filter (fun n -> is_suffix n "init") evars)) in + let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v])) + (List.filter (fun n -> is_suffix n "state") evars)) in + let lv_f = lv_f @ + (List.map (fun v -> (0.7, [v; "L"^v])) + (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in + let evars_ord = + if true then + time "FORCE" (fun () -> force_ordering evars lv_f) + else + time "SCA" (fun () -> scope_constrict evars lv) + in + + let evars_ord = + if false then + let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in + let vb, vc = List.partition (fun n -> is_suffix n "state") vb in + (List.rev va) @ vb @ vc + else + evars_ord + in + + let ev_order = Hashtbl.create (List.length evars) in + List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord; + + Format.printf "Order for variables: @[<hov>[%a]@]@." + (print_list Formula_printer.print_id ", ") evars_ord; + + (* calculate cycle variables and forget variables *) + let cycle = List.fold_left + (fun q (_, id, ty) -> + if id.[0] = 'L' then + (id, String.sub id 1 (String.length id - 1), ty)::q + else q) + [] last_vars + in + let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in + + { evars = enum_vars; nvars = num_vars; ev_order; + last_vars; all_vars; cycle; forget } + |