summaryrefslogblamecommitdiff
path: root/abstract/varenv.ml
blob: bcfa77a8636248cf31030b3d5a27066539bc5b5d (plain) (tree)


















                                           
                            


                                                                       
                                    



 






























                                               
























                                                                



















                                                             






























































































































                                                                                
                                                

                                  
                                                  



























                                                                            
                                                               

                                                               
                                                                                 


                                                                 
                                                  


















                                                                              





                                                                        











                                                                      




                                                            
 






                                                            
                                                      
 
open Ast
open Util
open Typing
open Formula



(* Name of one item (value) of an enumerated type. *)
type item = string

(* Enumerated variable: its identifier and the items of its type
   (taken from [TEnum] by [mk_varenv]). *)
type evar = id * item list
(* Numerical variable: its identifier and a flag, [true] for [TReal]
   variables and [false] for [TInt] ones (as built by [mk_varenv]). *)
type nvar = id * bool

(* Variable environment of a program, as computed by [mk_varenv]. *)
type varenv = {
    evars         : evar list;                (* enum variables, sorted by [ev_order] *)
    nvars         : nvar list;                (* numerical variables *)
    ev_order      : (id, int) Hashtbl.t;      (* position of each enum variable in the chosen order *)

    (* NOTE(review): the bool component comes from [rooted_prog.all_vars];
       LAST copies get [false] in [mk_varenv] -- confirm its meaning. *)
    last_vars     : (bool * id * typ) list;   (* "L"-prefixed variables read through a LAST *)
    all_vars      : (bool * id * typ) list;   (* [last_vars] followed by the program variables *)
    d_vars        : id list;                  (* enum variables chosen as disjunction variables *)

    (* (x, y, ty): s'(x) = s(y); [mk_varenv] stores ("L" ^ v, v, ty) *)
    cycle         : (id * id * typ) list;     (* s'(x) = s(y) *)
    forget        : (id * typ) list;          (* s'(x) not specified *)
    (* variables of [all_vars] that are not the target (second
       component) of any [cycle] entry *)
    forget_inv    : (id * typ) list;
}




(*
    Extract the variables accessed through a LAST, i.e. every identifier
    whose name starts with 'L', from a boolean formula. Identifiers are
    listed in left-to-right order of appearance; duplicates are kept.

    [elv_ne] does the same on a numerical expression and [elv_ec] on an
    enumerated constraint (_, variable, right-hand side).
*)

let rec extract_last_vars expr =
  match expr with
  | BRel (_, lhs, rhs, _) ->
      let from_lhs = elv_ne lhs in
      let from_rhs = elv_ne rhs in
      from_lhs @ from_rhs
  | BEnumCons cons -> elv_ec cons
  | BAnd (left, right) | BOr (left, right) ->
      let from_left = extract_last_vars left in
      let from_right = extract_last_vars right in
      from_left @ from_right
  | BNot inner -> extract_last_vars inner
  | BTernary (cond, if_true, if_false) ->
      let from_cond = extract_last_vars cond in
      let from_true = extract_last_vars if_true in
      let from_false = extract_last_vars if_false in
      from_cond @ (from_true @ from_false)
  | _ -> []

and elv_ne nexpr =
  match nexpr with
  | NIdent id -> if id.[0] = 'L' then [id] else []
  | NBinary (_, left, right, _) -> elv_ne left @ elv_ne right
  | NUnary (_, inner, _) -> elv_ne inner
  | _ -> []

and elv_ec (_, var, rhs) =
  let from_var = if var.[0] = 'L' then [var] else [] in
  let from_rhs =
    match rhs with
    | EIdent id -> if id.[0] = 'L' then [id] else []
    | _ -> []
  in
  from_var @ from_rhs



(*
  extract_linked_evars_root : (econs list * _ * _) -> (id * id) list

  Extract all pairs of enum-type variables (x, y) appearing in an
  enumerated constraint whose right-hand side is another variable
  (an equation like x = y or x != y; the first component of the
  constraint is ignored). Only the first component [ecl] of the
  constraint triple is inspected.

  A pair may appear several times in the result; pairs come out in
  reverse order of appearance.
*)
let extract_linked_evars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, rhs) ->
       match rhs with
       | EIdent y -> (x, y) :: acc
       | _ -> acc)
    [] ecl

(*
  extract_const_vars_root : (econs list * _ * _) -> id list

  Collect every enumerated variable x compared against a constant item
  (an [EItem] right-hand side) in the first component [ecl] of the
  constraint triple. Duplicates are kept; variables come out in reverse
  order of appearance.
*)
let extract_const_vars_root (ecl, _, _) =
  List.fold_left
    (fun acc (_, x, rhs) ->
       match rhs with
       | EItem _ -> x :: acc
       | _ -> acc)
    [] ecl


(*
  extract_choice_groups : formula -> (float * id list) list

  Build weighted groups of enumerated variables for [force_ordering].
  Every ternary BTernary (c, a, b) yields one group (w, vs), where vs
  is the sorted, duplicate-free (via [uniq_sorted]) list of the enum
  variables mentioned anywhere in c, a and b.

  A ternary that is not nested in another ternary gets weight 0.6
  (BAnd, BOr and BNot do not change the weight). A ternary nested in
  the condition of another gets the parent's weight divided by 3, and
  one nested in a branch gets the parent's weight divided by 2.
*)
let extract_choice_groups f =
  (* [aux w e] returns (groups found inside e, enum variables of e);
     [w] is the weight given to a ternary found at this level. *)
  let rec aux w = function
    | BNot n -> aux w n
    | BRel _ | BConst _ -> [], []
    | BEnumCons (_, x, EItem _) -> [], [x]
    (* x = y / x != y: both sides are enum variables of the group *)
    | BEnumCons (_, x, EIdent y) -> [], [x; y]
    | BAnd (a, b) | BOr (a, b) ->
      let ga, va = aux w a in
      let gb, vb = aux w b in
      ga @ gb, va @ vb
    | BTernary (c, a, b) ->
      let gc, vc = aux (w /. 3.) c in
      let ga, va = aux (w /. 2.) a in
      let gb, vb = aux (w /. 2.) b in
      let v = uniq_sorted (List.sort compare (vc @ va @ vb)) in
      (w, v) :: (gc @ ga @ gb), v
  in
  fst (aux 0.6 f)



(*
  scope_constrict : id list -> (id * id) list -> id list

  Order the variables of the first argument so as to minimize the sum,
  over all couples (x, y) of the second list, of the distance between
  the positions of x and y. The minimisation is approximate: this is a
  local-search heuristic (repeatedly try moving each variable to every
  position, keep any strict improvement) so that the EDD does not
  explode in size when expressing equations such as
  x = y && u = v && a != b.

  Every variable of a couple must belong to [vars], otherwise
  [Not_found] is raised. Progress dots are printed on standard output.
*)
let scope_constrict vars cp_id =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;

    let cp_i = List.map
      (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
      cp_id in

    (* Cost of a permutation [perm], where perm.(pos) is the index of the
       variable placed at position [pos]. *)
    let eval perm =
      let r = Array.make n (-1) in
      Array.iteri (fun pos var -> r.(var) <- pos) perm;
      Array.iter (fun x -> assert (x <> (-1))) r;
      List.fold_left
        (fun s (x, y) -> s + abs (r.(x) - r.(y)))
        0 cp_i
    in

    let best = Array.init n (fun i -> i) in
    (* Cost of [best], kept in sync with it so that it is not recomputed
       for every candidate. *)
    let best_score = ref (eval best) in

    let improved = ref true in
    Format.printf "SCA";
    while !improved do
      Format.printf ".@?";

      improved := false;
      let try_s candidate =
        let score = eval candidate in
        if score < !best_score then begin
          Array.blit candidate 0 best 0 n;
          best_score := score;
          improved := true
        end
      in

      for i = 0 to n-1 do
        let tt = Array.copy best in
        (* move item i at beginning *)
        let temp = tt.(i) in
        for j = i downto 1 do tt.(j) <- tt.(j-1) done;
        tt.(0) <- temp;
        (* slide it through every position, trying each one *)
        try_s tt;
        for j = 1 to n-1 do
          let temp = tt.(j-1) in
          tt.(j-1) <- tt.(j);
          tt.(j) <- temp;
          try_s tt
        done
      done
    done;
    Format.printf "@.";

    Array.to_list (Array.map (Array.get var_i) best)


(*
  force_ordering : id list -> (float * id list) list -> id list

  Determine a good ordering for enumerated variables with the FORCE
  heuristic. Each group (w, vs) of weight w has a center of gravity:
  the mean position of its members. The special name "#BEGIN" is a
  virtual member placed at position -w, which draws the other members
  of its group towards the beginning of the order.

  At each pass, a variable belonging to at least one group is ranked
  by the weighted (by w) mean of the centers of its groups. A variable
  belonging to no group is ranked by its index in [vars]. Variables are
  then re-sorted by rank.

  At most 501 passes are made. A pass only depends on the current
  order, so iteration stops as soon as a pass leaves the order
  unchanged; the result is the same as running all passes.

  Every variable of a group other than "#BEGIN" must belong to [vars],
  otherwise [Not_found] is raised.
*)
let force_ordering vars groups =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
    Hashtbl.add i_var "#BEGIN" (-1);

    let ngroups = List.map
      (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
      groups in

    (* One FORCE pass: [ord.(pos)] is the variable at position [pos];
       returns the new order. *)
    let step ord =
        let rev = Array.make n (-1) in
        Array.iteri (fun pos v -> rev.(v) <- pos) ord;

        let bw = Array.make n 0. in
        let w = Array.make n 0. in

        let gfun (gw, l) =
          (* center of gravity of the group; "#BEGIN" (-1) sits at -gw *)
          let sp = List.fold_left
            (fun acc i ->
              acc +. (if i = -1 then -.gw else float_of_int rev.(i)))
            0. l
          in
          let b = sp /. float_of_int (List.length l) in
          List.iter (fun i -> if i >= 0 then begin
                      bw.(i) <- bw.(i) +. (gw *. b);
                      w.(i) <- w.(i) +. gw end)
              l
        in
        List.iter gfun ngroups;

        let b = Array.init n
          (fun i ->
            if w.(i) = 0. then
                float_of_int i
            else bw.(i) /. w.(i)) in

        Array.of_list
          (List.sort
            (fun i j -> Pervasives.compare b.(i) b.(j))
            (Array.to_list ord))
    in

    let max_passes = 501 in
    let rec iterate pass ord =
      if pass >= max_passes then ord
      else begin
        let next = step ord in
        (* [step] is a function of [ord] only: a fixed point is final *)
        if next = ord then ord else iterate (pass + 1) next
      end
    in
    let ord = iterate 0 (Array.init n (fun i -> i)) in
    List.map (Array.get var_i) (Array.to_list ord)


(*
  mk_varenv : rooted_prog -> (id -> bool) -> formula -> constraint triple
              -> varenv

  Build the variable environment of a program:
  - [last_vars]: the "L"-prefixed identifiers read by [f] (found by
    [extract_last_vars]), each typed like the program variable u it
    refers to ("L" ^ u); raises [Not_found] if u is not in
    [rp.all_vars];
  - [all_vars]: [last_vars] followed by [rp.all_vars];
  - the enumerated variables and the numerical variables (flag [true]
    for reals, [false] for integers);
  - a good order for the enumerated variables, computed by the FORCE
    heuristic ([force_ordering]) from weighted groups of variables
    derived from [f] and from the enum constraints of [cl];
  - [cycle], [forget] and [forget_inv];
  - [d_vars]: the enumerated variables (in the chosen order) accepted
    by [disj_fun].

  The variables, the chosen order and the disjunction variables are
  printed on standard output.
*)

let mk_varenv (rp : rooted_prog) disj_fun f cl =
    (* add variables from LASTs: deduplicated "L"-prefixed identifiers of f *)
    let last_vars = uniq_sorted 
      (List.sort compare (extract_last_vars f)) in
    (* type "L" ^ u like u; List.find raises Not_found if u is unknown *)
    let last_vars = List.map
      (fun id ->
        let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
          in false, id, ty)
      last_vars in
    let all_vars = last_vars @ rp.all_vars in

    Format.printf "Vars: @[<hov>%a@]@.@."
        (print_list Ast_printer.print_typed_var ", ")
        all_vars;

    (* split by type; both lists end up in reverse order of all_vars *)
    let num_vars, enum_vars = List.fold_left
        (fun (nv, ev) (_, id, t) -> match t with
            | TEnum ch -> nv, (id, ch)::ev
            | TInt -> (id, false)::nv, ev
            | TReal -> (id, true)::nv, ev)
        ([], []) all_vars in

    (* calculate order for enumerated variables *)
    let evars = List.map fst enum_vars in

    (* pairs of enum variables compared with each other in cl *)
    let lv = extract_linked_evars_root cl in
    let lv = uniq_sorted
         (List.sort Pervasives.compare (List.map ord_couple lv)) in

    (* weighted groups for FORCE; "#BEGIN" is a virtual member that pulls
       the other member of its group towards the start of the order *)
    let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
    (* variables compared against a constant item *)
    let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
      (extract_const_vars_root cl)) in
    (* NOTE(review): is_suffix (from Util) presumably tests the end of the
       name -- "init" variables are pulled early *)
    let lv_f = lv_f @ (List.map (fun v -> (7.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "init") evars)) in
    let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "act" || is_suffix n "state") evars)) in
    (* keep each enum variable v close to its LAST copy "L" ^ v *)
    let lv_f = lv_f @
      (List.map (fun v -> (0.7, [v; "L"^v]))
        (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in
    (* variables involved in the same ternary choice *)
    let lv_f = lv_f @ (extract_choice_groups f) in
    (* hard-coded switch: FORCE is used, the SCA heuristic is disabled;
       [time] (from Util) presumably reports the elapsed time *)
    let evars_ord =
      if true then
        time "FORCE" (fun () -> force_ordering evars lv_f)
      else
        time "SCA" (fun () -> scope_constrict evars lv)
    in

    (* hard-coded switch (disabled): reorder as reversed "init"
       variables, then "state" variables, then the rest *)
    let evars_ord =
      if false then
        let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
        let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
        (List.rev va) @ vb @ vc
      else
        evars_ord
    in

    (* position of each enum variable in the chosen order *)
    let ev_order = Hashtbl.create (List.length evars) in
    List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;

    let enum_vars = List.sort
      (fun (id1, _) (id2, _) ->
        compare (Hashtbl.find ev_order id1) (Hashtbl.find ev_order id2))
      enum_vars
    in

    Format.printf "Order for variables: @[<hov>[%a]@]@."
      (print_list Formula_printer.print_id ", ") evars_ord;

    (* calculate cycle variables and forget variables *)
    (* cycle: each LAST variable "L" ^ x is linked to x *)
    let cycle = List.fold_left
      (fun q (_, id, ty) ->
          if id.[0] = 'L' then
            (id, String.sub id 1 (String.length id - 1), ty)::q
          else q)
      [] last_vars
    in
    (* forget: every program variable (LAST copies excluded) *)
    let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
    (* forget_inv: every variable that is not the target of a cycle *)
    let forget_inv = List.map (fun (_, id, ty) -> (id, ty))
      (List.filter
        (fun (_, id, _) ->
          not (List.exists (fun (_, b, _) -> b = id) cycle))
        all_vars) in

    (* use specified disjunction variables *)
    let d_vars = List.filter disj_fun
        (List.map (fun (id, _) -> id) enum_vars) in
    Format.printf "Disjunction variables: @[<hov>[%a]@]@."
      (print_list Formula_printer.print_id ", ") d_vars;

    { evars = enum_vars; nvars = num_vars; ev_order; d_vars;
      last_vars; all_vars; cycle; forget; forget_inv }