open Abstract_syntax_tree
open Environment_domain
open Util
(** Abstract interpreter parameterised by an environment abstract domain [E].
    Statements are interpreted as transfer functions on [E.t]; boolean
    conditions are used to filter environments. *)
module Make (E : ENVIRONMENT_DOMAIN) = struct
  (** [neg e] is the syntactic negation [!e], reusing [e]'s extent. *)
  let neg e = (AST_unary (AST_NOT, e), snd e)

  (** [binop op e e2] builds [e op e2], using [e]'s extent for the result. *)
  let binop op e e2 = (AST_binary (op, e, e2), snd e)

  (** [m1 e] is the expression [e - 1]. *)
  let m1 e = binop AST_MINUS e (AST_int_const ("1", snd e), snd e)

  (** [p1 e] is the expression [e + 1]. *)
  let p1 e = binop AST_PLUS e (AST_int_const ("1", snd e), snd e)

  (** [bottom_with_vars vlist] is [E.bottom] with every variable of [vlist]
      added through [E.addvar], so that an unreachable environment still
      carries the set of variables in scope. *)
  let bottom_with_vars vlist = List.fold_left E.addvar E.bottom vlist

  (** [condition cond env] filters [env] by the boolean expression [cond].

      Only [<=] and [==] are handed to the domain ([E.compare_leq],
      [E.compare_eq]); [&&]/[||] become [E.meet]/[E.join]. Negations are
      pushed inward (De Morgan, double negation, flipped comparison
      operators), and [<], [>=], [>], [!=] are rewritten in terms of [<=].
      Any expression not recognised here leaves [env] unchanged, which is
      a sound over-approximation. *)
  let rec condition cond env =
    match fst cond with
    | AST_binary (AST_LESS_EQUAL, e1, e2) -> E.compare_leq env e1 e2
    | AST_binary (AST_EQUAL, e1, e2) -> E.compare_eq env e1 e2
    | AST_binary (AST_AND, e1, e2) ->
        E.meet (condition e1 env) (condition e2 env)
    | AST_binary (AST_OR, e1, e2) ->
        E.join (condition e1 env) (condition e2 env)
    | AST_bool_const true -> env
    | AST_bool_const false ->
        (* Unreachable, but keep the variable set exactly as AST_HALT and
           loop unrolling do; a bare [E.bottom] would drop every variable
           from the rest of the analysis (scoping, final output). *)
        bottom_with_vars (E.vars env)
    (* Negation elimination. *)
    | AST_unary (AST_NOT, (AST_bool_const b, _)) ->
        condition (AST_bool_const (not b), snd cond) env
    | AST_unary (AST_NOT, (AST_unary (AST_NOT, inner), _)) ->
        condition inner env
    | AST_unary (AST_NOT, (AST_binary (AST_AND, e1, e2), ext)) ->
        condition (AST_binary (AST_OR, neg e1, neg e2), ext) env
    | AST_unary (AST_NOT, (AST_binary (AST_OR, e1, e2), ext)) ->
        condition (AST_binary (AST_AND, neg e1, neg e2), ext) env
    | AST_unary (AST_NOT, (AST_binary (op, e1, e2), _)) ->
        let flipped =
          match op with
          | AST_LESS_EQUAL -> AST_GREATER
          | AST_LESS -> AST_GREATER_EQUAL
          | AST_GREATER_EQUAL -> AST_LESS
          | AST_GREATER -> AST_LESS_EQUAL
          | AST_EQUAL -> AST_NOT_EQUAL
          | AST_NOT_EQUAL -> AST_EQUAL
          (* NOTE(review): presumably arithmetic operators cannot appear
             directly under a boolean negation (type checking) — confirm. *)
          | _ -> assert false
        in
        condition (binop flipped e1 e2) env
    (* Encode the remaining comparisons with <=. *)
    | AST_binary (AST_LESS, e1, e2) ->
        (* On integers, e1 < e2 is equivalent to both e1 <= e2 - 1 and
           e1 + 1 <= e2; both forms are kept, presumably so that the domain
           can refine whichever side is a plain variable. *)
        condition
          (binop AST_AND
             (binop AST_LESS_EQUAL e1 (m1 e2))
             (binop AST_LESS_EQUAL (p1 e1) e2))
          env
    | AST_binary (AST_GREATER_EQUAL, e1, e2) ->
        condition (binop AST_LESS_EQUAL e2 e1) env
    | AST_binary (AST_GREATER, e1, e2) ->
        condition (binop AST_LESS e2 e1) env
    | AST_binary (AST_NOT_EQUAL, e1, e2) ->
        condition
          (binop AST_OR (binop AST_LESS e1 e2) (binop AST_LESS e2 e1))
          env
    | _ -> env

  (** [interp_stmt env stat] is the abstract post-state of [stat] executed
      from [env]. [print] statements and failed assertions write to stdout. *)
  let rec interp_stmt env stat =
    match fst stat with
    | AST_block b ->
        (* Variables declared inside the block go out of scope at its end. *)
        let prevars = E.vars env in
        let env2 = List.fold_left interp_stmt env b in
        let rmvars =
          List.filter (fun x -> not (List.mem x prevars)) (E.vars env2)
        in
        List.fold_left E.rmvar env2 rmvars
    | AST_assign ((id, _), exp) -> E.assign env id exp
    | AST_if (cond, tb, None) ->
        E.join (interp_stmt (condition cond env) tb) (condition (neg cond) env)
    | AST_if (cond, tb, Some eb) ->
        (* Sequential lets keep the then-branch output before the else-branch. *)
        let e1 = interp_stmt (condition cond env) tb in
        let e2 = interp_stmt (condition (neg cond) env) eb in
        E.join e1 e2
    | AST_while (cond, body) ->
        (* Number of loop iterations unrolled before computing an invariant. *)
        let unroll_count = 3 in
        (* Number of plain-join iterations performed before switching to
           widening. *)
        let widen_delay = 3 in
        (* [unroll u n] runs the body [n] times from [u]. It returns the state
           after [n] iterations and the join of the states that leave the
           loop after 0 .. n-1 iterations. *)
        let rec unroll u = function
          | 0 -> (u, bottom_with_vars (E.vars env))
          | n ->
              let prev_u, u_prev_u = unroll u (n - 1) in
              ( interp_stmt (condition cond prev_u) body,
                E.join u_prev_u (condition (neg cond) prev_u) )
        in
        (* From here on, [env] is the state after the unrolled iterations. *)
        let env, u_u = unroll env unroll_count in
        (* Abstract loop functional: entry state joined with one more pass. *)
        let fsharp i = E.join env (interp_stmt (condition cond i) body) in
        (* Increasing iteration: join for the first [widen_delay] steps, then
           widen, until the iterate stops changing.
           NOTE(review): stability is tested with polymorphic [=] on E.t,
           which assumes equal abstract values are structurally equal —
           confirm against the domain's representation. *)
        let rec iter n i =
          let step = if n < widen_delay then E.join else E.widen in
          let i' = step i (fsharp i) in
          if i = i' then i else iter (n + 1) i'
        in
        let x = iter 0 env in
        (* Decreasing iteration from the widened post-fixpoint via [fix]
           (not defined in this file — see its definition for termination). *)
        let y = fix fsharp x in
        E.join (condition (neg cond) y) u_u
    | AST_HALT -> bottom_with_vars (E.vars env)
    | AST_assert cond ->
        (* Report an error when the negated assertion may hold in [env];
           analysis then continues with the states satisfying [cond]. *)
        if not (E.is_bot (condition (neg cond) env)) then begin
          Format.printf "%s: ERROR: assertion failure@."
            (Abstract_syntax_printer.string_of_extent (snd stat))
        end;
        condition cond env
    | AST_print items ->
        Format.printf "%s: %s@."
          (Abstract_syntax_printer.string_of_extent (snd stat))
          (E.var_str env (List.map fst items));
        env
    | AST_local (_, vars) ->
        (* Declare each variable in order, then apply its initialiser. *)
        List.fold_left
          (fun env ((id, _), init) ->
            let env = E.addvar env id in
            match init with
            | Some e -> E.assign env id e
            | None -> env)
          env vars
    | _ -> assert false (* not implemented *)

  (** [interpret prog] runs every top-level statement of [prog] from
      [E.init] (other top-level items are skipped) and prints the final
      environment's variables. *)
  let interpret prog =
    let step env = function AST_stat st -> interp_stmt env st | _ -> env in
    let result = List.fold_left step E.init (fst prog) in
    Format.printf "Output: %s@." (E.var_str result (E.vars result))
end