(* Source: root/abstract/varenv.ml
   blob: 1aa17b45ec6c972231ffe4733b17b5de1084af2a *)



























































































































































































































































                                                                                
open Ast
open Util
open Typing
open Formula



(* Name of an item of an enumerated type (the items of a TEnum type are
   stored as a list of these, see mk_varenv). *)
type item = string

(* An enumerated variable together with the list of items of its TEnum type. *)
type evar = id * item list
(* A numerical variable; mk_varenv sets the flag to true for TReal and
   false for TInt. *)
type nvar = id * bool

(* Variable environment, as built by mk_varenv. *)
type varenv = {
    (* enumerated variables with their items *)
    evars         : evar list;
    (* numerical variables with their real/int flag *)
    nvars         : nvar list;
    (* position of each enumerated variable in the chosen variable order *)
    ev_order      : (id, int) Hashtbl.t;

    (* LAST variables ("L" ^ x, typed like x); mk_varenv sets the boolean
       to false — NOTE(review): meaning of the flag not visible here *)
    last_vars     : (bool * id * typ) list;
    (* last_vars followed by the program's own variables *)
    all_vars      : (bool * id * typ) list;

    cycle         : (id * id * typ) list;     (* s'(x) = s(y) *)
    forget        : (id * typ) list;          (* s'(x) not specified *)
}



(*
  extract_linked_evars_root : (ecl, _, _) -> (id * id) list

  Extract all pairs (x, y) of enum-type variables appearing in an
  equation like x = y or x != y, i.e. every element (_, x, EIdent y)
  of the first component [ecl] of the root triple. The other two
  components of the triple are ignored.

  A couple may appear several times in the result.
*)
let extract_linked_evars_root (ecl, _, _) =
    List.fold_left
        (fun c (_, x, v) -> match v with
            | EIdent y -> (x, y)::c
            | _ -> c)
        [] ecl

(*
  extract_const_vars_root : (ecl, _, _) -> id list

  Return every variable x of an element (_, x, EItem _) of the first
  component [ecl] of the root triple, i.e. the enum variables compared
  with a constant item. A variable may appear several times.
*)
let extract_const_vars_root (ecl, _, _) =
    List.fold_left
      (fun l (_, x, v) -> match v with
            | EItem _ -> x::l
            | _ -> l)
      [] ecl



(*
  scope_constrict : id list -> (id * id) list -> id list

  Reorders the variables of the first argument so as to minimise the
  sum, over all couples (x, y) of the second list, of the distance
  between the positions of x and y. The minimisation is approximate:
  this is a local-search heuristic, used so that the EDD will not
  explode in size when expressing equations such as
  x = y && u = v && a != b.

  Every variable appearing in a couple must belong to the first list,
  otherwise Not_found is raised. Progress is printed on stdout.
*)
let scope_constrict vars cp_id =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;

    let cp_i = List.map
      (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
      cp_id in

    (* Cost of permutation [perm] (perm.(pos) = index of the variable
       placed at position pos): sum of distances between linked variables. *)
    let eval perm =
      let r = Array.make n (-1) in
      Array.iteri (fun pos var -> r.(var) <- pos) perm;
      Array.iteri (fun _ x -> assert (x <> (-1))) r;
      List.fold_left
        (fun s (x, y) -> s + abs (r.(x) - r.(y)))
        0 cp_i
    in

    let best = Array.init n (fun i -> i) in
    (* Cached cost of [best]; only updated together with [best], so this is
       exactly what re-evaluating [best] for every candidate would give. *)
    let best_score = ref (eval best) in

    let improved = ref true in
    Format.printf "SCA";
    while !improved do
      Format.printf ".@?";

      improved := false;
      let try_s perm =
        let s = eval perm in
        if s < !best_score then begin
          Array.blit perm 0 best 0 n;
          best_score := s;
          improved := true
        end
      in

      for i = 0 to n-1 do
        let tt = Array.copy best in
        (* move item i at beginning *)
        let temp = tt.(i) in
        for j = i downto 1 do tt.(j) <- tt.(j-1) done;
        tt.(0) <- temp;
        (* try all positions, bubbling the item towards the end *)
        try_s tt;
        for j = 1 to n-1 do
          let temp = tt.(j-1) in
          tt.(j-1) <- tt.(j);
          tt.(j) <- temp;
          try_s tt
        done
      done
    done;
    Format.printf "@.";

    Array.to_list (Array.map (Array.get var_i) best)


(*
  force_ordering : id list -> (float * id list) list -> id list

  Determine a good ordering for enumerated variables based on the FORCE
  algorithm. Each group (w, l) is a hyperedge of weight w. On every pass,
  each variable gets the weighted mean of the centres of gravity of the
  groups it belongs to, and variables are re-sorted by that value. 501
  passes are made.

  The pseudo-variable "#BEGIN" may appear in a group. It stands for the
  fixed position -. w, which pulls the other members of the group towards
  the beginning of the order. The pull is stronger for a larger w.
  Variables with zero total weight (e.g. in no group) get their original
  index in [vars] as value. Every other name in a group must belong to
  [vars], otherwise Not_found is raised.
*)
let force_ordering vars groups =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
    Hashtbl.add i_var "#BEGIN" (-1);

    let ngroups = List.map
      (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
      groups in

    (* ord.(pos) = index of the variable currently at position pos *)
    let ord = Array.init n (fun i -> i) in

    for _pass = 0 to 500 do
        (* rev.(v) = current position of variable v *)
        let rev = Array.make n (-1) in
        for i = 0 to n-1 do rev.(ord.(i)) <- i done;

        let bw = Array.make n 0. in
        let w = Array.make n 0. in

        (* accumulate the group's centre of gravity into its members *)
        let gfun (gw, l) =
          let sp = List.fold_left (+.) 0.
            (List.map
              (fun i -> if i = -1 then -.gw else float_of_int (rev.(i))) l)
          in
          let b = sp /. float_of_int (List.length l) in
          List.iter (fun i -> if i >= 0 then begin
                      bw.(i) <- bw.(i) +. (gw *. b);
                      w.(i) <- w.(i) +. gw end)
              l
        in
        List.iter gfun ngroups;

        let b = Array.init n
          (fun i ->
            if w.(i) = 0. then
                float_of_int i
            else bw.(i) /. w.(i)) in

        (* plain polymorphic compare: Pervasives was removed in OCaml 5.0 *)
        let ol = List.sort
          (fun i j -> compare b.(i) b.(j))
          (Array.to_list ord) in
        Array.blit (Array.of_list ol) 0 ord 0 n
    done;
    List.map (Array.get var_i) (Array.to_list ord)


(*
  Make varenv: takes a program [rp], a formula [f] (scanned for LAST
  variables) and a constraint root [cl], and extracts
  - list of enum variables (with their items)
  - list of num variables (flag true for TReal, false for TInt)
  - good order for enum variables (heuristic, see force_ordering)
  - cycle, forget

  Raises Not_found if a LAST variable "Lx" found in [f] has no matching
  variable "x" in rp.all_vars.
*)

let mk_varenv (rp : rooted_prog) f cl =
    (* add variables from LASTs: each "Lx" gets the type of "x" *)
    let last_vars = uniq_sorted
      (List.sort compare (Transform.extract_last_vars f)) in
    let last_vars = List.map
      (fun id ->
        let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
          in false, id, ty)
      last_vars in
    let all_vars = last_vars @ rp.all_vars in

    Format.printf "Vars: @[<hov>%a@]@.@."
        (print_list Ast_printer.print_typed_var ", ")
        all_vars;

    (* split variables by type (both lists come out in reverse order) *)
    let num_vars, enum_vars = List.fold_left
        (fun (nv, ev) (_, id, t) -> match t with
            | TEnum ch -> nv, (id, ch)::ev
            | TInt -> (id, false)::nv, ev
            | TReal -> (id, true)::nv, ev)
        ([], []) all_vars in

    (* calculate order for enumerated variables *)
    let evars = List.map fst enum_vars in

    (* linked pairs, passed through ord_couple (presumably a canonical
       ordering of each couple) and deduplicated *)
    let lv = extract_linked_evars_root cl in
    let lv = uniq_sorted
         (List.sort compare (List.map ord_couple lv)) in

    (* weighted groups for FORCE; "#BEGIN" pulls a variable towards the
       start of the order, more strongly for larger weights *)
    let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
    let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
      (extract_const_vars_root cl)) in
    let lv_f = lv_f @ (List.map (fun v -> (5.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "init") evars)) in
    let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "state") evars)) in
    (* keep each variable x close to its LAST copy "Lx" *)
    let lv_f = lv_f @
      (List.map (fun v -> (0.7, [v; "L"^v]))
        (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in

    (* experiment switch: FORCE ordering, or scope_constrict when false *)
    let use_force_ordering = true in
    let evars_ord =
      if use_force_ordering then
        time "FORCE" (fun () -> force_ordering evars lv_f)
      else
        time "SCA" (fun () -> scope_constrict evars lv)
    in

    (* experiment switch: move "init" then "state" variables to the front *)
    let regroup_init_state = false in
    let evars_ord =
      if regroup_init_state then
        let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
        let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
        (List.rev va) @ vb @ vc
      else
        evars_ord
    in

    let ev_order = Hashtbl.create (List.length evars) in
    List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;

    Format.printf "Order for variables: @[<hov>[%a]@]@."
      (print_list Formula_printer.print_id ", ") evars_ord;

    (* calculate cycle variables and forget variables:
       cycle links each "Lx" to "x"; every program variable is forgotten *)
    let cycle = List.fold_left
      (fun q (_, id, ty) ->
          if id.[0] = 'L' then
            (id, String.sub id 1 (String.length id - 1), ty)::q
          else q)
      [] last_vars
    in
    let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in

    { evars = enum_vars; nvars = num_vars; ev_order;
      last_vars; all_vars; cycle; forget }