summaryrefslogtreecommitdiff
path: root/abstract/varenv.ml
blob: bcfa77a8636248cf31030b3d5a27066539bc5b5d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
open Ast
open Util
open Typing
open Formula



(* Element of an enumerated type, as a string (the element type of TEnum). *)
type item = string

(* Enumerated variable: its identifier and the list of items of its type
   (built from [TEnum ch] in mk_varenv). *)
type evar = id * item list
(* Numeric variable: its identifier and a flag that mk_varenv sets to
   [true] for TReal variables and [false] for TInt variables. *)
type nvar = id * bool

(* Variable environment built by mk_varenv. *)
type varenv = {
    (* enum variables, sorted according to ev_order *)
    evars         : evar list;
    (* numeric variables *)
    nvars         : nvar list;
    (* enum variable -> its rank in the chosen variable order *)
    ev_order      : (id, int) Hashtbl.t;

    (* "L"-prefixed variables standing for LAST(...) expressions; their
       bool component is always false (the meaning of the bool for program
       variables is defined elsewhere -- TODO confirm) *)
    last_vars     : (bool * id * typ) list;
    (* last_vars followed by the program's own variables *)
    all_vars      : (bool * id * typ) list;
    (* enum variables selected as disjunction variables *)
    d_vars        : id list;

    (* entries are (Lx, x, ty) for each LAST variable *)
    cycle         : (id * id * typ) list;     (* s'(x) = s(y) *)
    forget        : (id * typ) list;          (* s'(x) not specified *)
    (* variables of all_vars that are not the target x of a cycle entry *)
    forget_inv    : (id * typ) list;
}




(*
    Extract variables accessed by a LAST.

    Collects every identifier whose first character is 'L' (the variables
    standing for LAST(...) expressions) occurring in a formula: in the
    numeric operands of relations and in enumerated constraints.
    Duplicates are kept; callers deduplicate.
*)

let rec extract_last_vars e =
  match e with
  | BRel (_, lhs, rhs, _) -> elv_ne lhs @ elv_ne rhs
  | BEnumCons ec -> elv_ec ec
  | BAnd (l, r) | BOr (l, r) -> extract_last_vars l @ extract_last_vars r
  | BNot inner -> extract_last_vars inner
  | BTernary (cond, if_t, if_f) ->
      List.concat (List.map extract_last_vars [cond; if_t; if_f])
  | _ -> []

(* LAST variables occurring in a numeric expression. *)
and elv_ne ne =
  match ne with
  | NIdent name when name.[0] = 'L' -> [name]
  | NBinary (_, l, r, _) -> elv_ne l @ elv_ne r
  | NUnary (_, sub, _) -> elv_ne sub
  | _ -> []

(* LAST variables occurring in an enumerated constraint: the constrained
   variable first, then the right-hand side when it is an identifier. *)
and elv_ec (_, v, q) =
  let candidates = match q with EIdent i -> [v; i] | _ -> [v] in
  List.filter (fun name -> name.[0] = 'L') candidates



(*
  extract_linked_evars_root : conslist -> (id * id) list

  Extract all pairs of enum-type variables (x, y) appearing in an
  enumerated constraint of the form x <op> y (e.g. x = y or x != y) in
  the first component of the constraint triple. The comparison operator
  is ignored; constraints whose right-hand side is not an identifier are
  skipped.

  A couple may appear several times in the result, and couples come out
  in the reverse order of the constraint list.
*)
let extract_linked_evars_root (ecl, _, _) =
    List.fold_left
        (fun acc (_, x, v) -> match v with
            | EIdent y -> (x, y) :: acc
            | _ -> acc)
        [] ecl

(*
  extract_const_vars_root : conslist -> id list

  Extract the enum-type variables x constrained against a constant item
  (x <op> item) in the first component of the constraint triple.
  A variable may appear several times; results come out in the reverse
  order of the constraint list.
*)
let extract_const_vars_root (ecl, _, _) =
    List.fold_left
      (fun acc (_, x, v) -> match v with
            | EItem _ -> x :: acc
            | _ -> acc)
      [] ecl


(*
  extract_choice_groups : weighted variable groups for force_ordering

  Walks a formula and, for every ternary node c ? a : b, emits a group
  (w, vars). Here vars is the sorted, duplicate-free list of variables
  collected from the enumerated constraints inside c, a and b. The root
  weight is 0.6. The condition of a ternary is explored with a third of
  the current weight and each branch with half of it, so nested choices
  yield lighter groups.

  NOTE(review): for a constraint x <op> y only y is recorded, while x is
  recorded for x <op> item -- confirm this asymmetry is intended.
*)
let extract_choice_groups f =
  (* [collect weight e] = (groups found in e, variables found in e) *)
  let rec collect weight e =
    match e with
    | BNot inner -> collect weight inner
    | BRel _ | BConst _ -> ([], [])
    | BEnumCons (_, x, EItem _) -> ([], [x])
    | BEnumCons (_, _, EIdent y) -> ([], [y])
    | BAnd (l, r) | BOr (l, r) ->
        let (groups_l, vars_l) = collect weight l in
        let (groups_r, vars_r) = collect weight r in
        (groups_l @ groups_r, vars_l @ vars_r)
    | BTernary (cond, yes, no) ->
        let (g_cond, v_cond) = collect (weight /. 3.) cond in
        let (g_yes, v_yes) = collect (weight /. 2.) yes in
        let (g_no, v_no) = collect (weight /. 2.) no in
        let vars = uniq_sorted (List.sort compare (v_cond @ v_yes @ v_no)) in
        ((weight, vars) :: (g_cond @ g_yes @ g_no), vars)
  in
  let (groups, _) = collect 0.6 f in
  groups



(*
  scope_constrict : id list -> (id * id) list -> id list

  Orders the variables of the first argument so as to minimize the sum,
  over all couples of the second list, of the distance between the
  positions of the two variables of the couple. The minimisation is
  approximate: it is a local search, used as a heuristic so that the EDD
  will not explode in size when expressing equations such as
  x = y && u = v && a != b.

  Each pass tries moving every variable to every position and keeps any
  strictly cheaper ordering; passes repeat until one makes no improvement.
  Every variable of a couple must belong to the first list, otherwise
  Not_found is raised. Progress is printed on standard output ("SCA"
  followed by one dot per pass).
*)
let scope_constrict vars cp_id =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;

    let cp_i = List.map
      (fun (x, y) -> Hashtbl.find i_var x, Hashtbl.find i_var y)
      cp_id in

    (* Cost of an ordering [ord] (ord.(p) = index of the variable at
       position p): sum of the position distances of each couple. *)
    let eval ord =
      let pos = Array.make n (-1) in
      Array.iteri (fun p var -> pos.(var) <- p) ord;
      Array.iter (fun p -> assert (p <> -1)) pos;
      List.fold_left
        (fun s (x, y) -> s + abs (pos.(x) - pos.(y)))
        0 cp_i
    in

    let best = Array.init n (fun i -> i) in
    (* cost of [best], cached so each candidate needs a single evaluation *)
    let best_cost = ref (eval best) in

    let improved = ref true in
    Format.printf "SCA";
    while !improved do
      Format.printf ".@?";

      improved := false;
      (* adopt [cand] if it is strictly cheaper than the current best *)
      let try_s cand =
        let c = eval cand in
        if c < !best_cost then begin
          Array.blit cand 0 best 0 n;
          best_cost := c;
          improved := true
        end
      in

      for i = 0 to n-1 do
        let tt = Array.copy best in
        (* move item i to the beginning *)
        let temp = tt.(i) in
        for j = i downto 1 do tt.(j) <- tt.(j-1) done;
        tt.(0) <- temp;
        (* then slide it through every position, trying each one *)
        try_s tt;
        for j = 1 to n-1 do
          let temp = tt.(j-1) in
          tt.(j-1) <- tt.(j);
          tt.(j) <- temp;
          try_s tt
        done
      done
    done;
    Format.printf "@.";

    Array.to_list (Array.map (Array.get var_i) best)


(*
  force_ordering : id list -> (float * id list) list -> id list

  Determine a good ordering for enumerated variables using an iterative
  center-of-gravity heuristic in the style of the FORCE algorithm.

  Each group (w, l) is a weighted hyperedge. At each of 501 iterations,
  every variable gets as target the weighted average (by w) of the
  centers of gravity of the groups containing it. The variables are then
  stably re-sorted by target.

  The pseudo-variable "#BEGIN" may appear in a group. It is placed at
  position -w, which pulls the other members of that group toward the
  start of the order.

  Raises Not_found if a group mentions a name that is neither in [vars]
  nor "#BEGIN".
*)
let force_ordering vars groups =
    let var_i = Array.of_list vars in
    let n = Array.length var_i in

    let i_var = Hashtbl.create n in
    Array.iteri (fun i v -> Hashtbl.add i_var v i) var_i;
    Hashtbl.add i_var "#BEGIN" (-1);

    let ngroups = List.map
      (fun (w, l) -> w, List.map (Hashtbl.find i_var) l)
      groups in

    (* ord.(pos) is the index (in var_i) of the variable at position pos *)
    let ord = Array.init n (fun i -> i) in

    for _iter = 0 to 500 do
        (* rev.(v) is the current position of variable v *)
        let rev = Array.make n (-1) in
        Array.iteri (fun pos v -> rev.(v) <- pos) ord;

        (* bw.(v): weighted sum of the centers of gravity of v's groups;
           w.(v): sum of the weights of v's groups *)
        let bw = Array.make n 0. in
        let w = Array.make n 0. in

        let gfun (gw, l) =
          (* "#BEGIN" (index -1) sits at position -gw *)
          let sp = List.fold_left
            (fun acc i ->
              acc +. (if i = -1 then -.gw else float_of_int rev.(i)))
            0. l
          in
          let b = sp /. float_of_int (List.length l) in
          List.iter (fun i -> if i >= 0 then begin
                      bw.(i) <- bw.(i) +. (gw *. b);
                      w.(i) <- w.(i) +. gw end)
              l
        in
        List.iter gfun ngroups;

        (* Target position of every variable. A variable in no group uses
           its index in [vars]. NOTE(review): that is its original index,
           not its current position rev.(i) -- confirm this is intended. *)
        let b = Array.init n
          (fun i ->
            if w.(i) = 0. then
                float_of_int i
            else bw.(i) /. w.(i)) in

        (* stable in-place sort: ties keep their current relative order *)
        Array.stable_sort (fun i j -> compare b.(i) b.(j)) ord
    done;
    List.map (Array.get var_i) (Array.to_list ord)


(*
  Make varenv : takes a program, and extracts
  - list of enum variables
  - list of num variables
  - good order for enum variables
  - cycle, forget

  Parameters:
  - rp       : the rooted program; rp.all_vars lists its variables
  - disj_fun : predicate selecting which enum variables become
               disjunction variables (d_vars)
  - f        : formula scanned for LAST variables ("L" prefix) and for
               ternary choice groups
  - cl       : constraint triple whose first component is scanned for
               enum variables linked to other variables or to constants

  Raises Not_found if f mentions a variable "Lx" such that x is not in
  rp.all_vars. Prints the variables, the chosen enum order and the
  disjunction variables on standard output.
*)

let mk_varenv (rp : rooted_prog) disj_fun f cl =
    (* add variables from LASTs: each "Lx" found in f gets the type of x *)
    let last_vars = uniq_sorted
      (List.sort compare (extract_last_vars f)) in
    let last_vars = List.map
      (fun id ->
        let (_, _, ty) = List.find (fun (_, u, _) -> id = "L"^u) rp.all_vars
          in false, id, ty)
      last_vars in
    let all_vars = last_vars @ rp.all_vars in

    Format.printf "Vars: @[<hov>%a@]@.@."
        (print_list Ast_printer.print_typed_var ", ")
        all_vars;

    (* split variables by type; numeric ones are tagged false for TInt
       and true for TReal *)
    let num_vars, enum_vars = List.fold_left
        (fun (nv, ev) (_, id, t) -> match t with
            | TEnum ch -> nv, (id, ch)::ev
            | TInt -> (id, false)::nv, ev
            | TReal -> (id, true)::nv, ev)
        ([], []) all_vars in

    (* calculate order for enumerated variables *)
    let evars = List.map fst enum_vars in

    (* couples of enum variables compared with each other, deduplicated *)
    let lv = extract_linked_evars_root cl in
    let lv = uniq_sorted
         (List.sort compare (List.map ord_couple lv)) in

    (* Weighted groups for force_ordering. A group containing "#BEGIN"
       pulls its other member toward the start of the order. *)
    (* linked couples, weight 1.0 *)
    let lv_f = List.map (fun (a, b) -> (1.0, [a; b])) lv in
    (* variables constrained against a constant, weight 10.0 *)
    let lv_f = lv_f @ (List.map (fun v -> (10.0, ["#BEGIN"; v]))
      (extract_const_vars_root cl)) in
    (* variables selected by [is_suffix n "init"], weight 7.0 *)
    let lv_f = lv_f @ (List.map (fun v -> (7.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "init") evars)) in
    (* variables selected by the "act" / "state" suffix tests, weight 3.0 *)
    let lv_f = lv_f @ (List.map (fun v -> (3.0, ["#BEGIN"; v]))
      (List.filter (fun n -> is_suffix n "act" || is_suffix n "state") evars)) in
    (* a variable and its LAST copy "L"^v, when both are enum variables *)
    let lv_f = lv_f @
      (List.map (fun v -> (0.7, [v; "L"^v]))
        (List.filter (fun n -> List.mem ("L"^n) evars) evars)) in
    (* choice groups from the ternaries of f *)
    let lv_f = lv_f @ (extract_choice_groups f) in
    (* hard-wired switch: FORCE is used, the SCA local search is kept as
       an alternative *)
    let evars_ord =
      if true then
        time "FORCE" (fun () -> force_ordering evars lv_f)
      else
        time "SCA" (fun () -> scope_constrict evars lv)
    in

    (* hard-wired switch (disabled): put the "init" variables first (in
       reverse order), then the "state" ones, then the rest *)
    let evars_ord =
      if false then
        let va, vb = List.partition (fun n -> is_suffix n "init") evars_ord in
        let vb, vc = List.partition (fun n -> is_suffix n "state") vb in
        (List.rev va) @ vb @ vc
      else
        evars_ord
    in

    (* rank of each enum variable in the chosen order *)
    let ev_order = Hashtbl.create (List.length evars) in
    List.iteri (fun i x -> Hashtbl.add ev_order x i) evars_ord;

    let enum_vars = List.sort
      (fun (id1, _) (id2, _) ->
        compare (Hashtbl.find ev_order id1) (Hashtbl.find ev_order id2))
      enum_vars
    in

    Format.printf "Order for variables: @[<hov>[%a]@]@."
      (print_list Formula_printer.print_id ", ") evars_ord;

    (* calculate cycle variables and forget variables *)
    (* cycle: (Lx, x, ty) for each LAST variable; extract_last_vars only
       returns 'L'-prefixed ids, so the test below always holds *)
    let cycle = List.fold_left
      (fun q (_, id, ty) ->
          if id.[0] = 'L' then
            (id, String.sub id 1 (String.length id - 1), ty)::q
          else q)
      [] last_vars
    in
    let forget = List.map (fun (_, id, ty) -> (id, ty)) rp.all_vars in
    (* every variable that is not the target x of a cycle entry *)
    let forget_inv = List.map (fun (_, id, ty) -> (id, ty))
      (List.filter
        (fun (_, id, _) ->
          not (List.exists (fun (_, b, _) -> b = id) cycle))
        all_vars) in

    (* use specified disjunction variables *)
    let d_vars = List.filter disj_fun
        (List.map (fun (id, _) -> id) enum_vars) in
    Format.printf "Disjunction variables: @[<hov>[%a]@]@."
      (print_list Formula_printer.print_id ", ") d_vars;

    { evars = enum_vars; nvars = num_vars; ev_order; d_vars;
      last_vars; all_vars; cycle; forget; forget_inv }