open Ast
open Formula
open Util

module type ENUM_ENVIRONMENT_DOMAIN2 = sig
    (** Abstract environments over enumerated variables. *)
    type t

    (** Values taken by enumerated variables. *)
    type item = string

    (* construction *)

    (** [top vars] is the unconstrained environment over [vars]; each
        variable is given with the list of items it may take. *)
    val top           : (id * item list) list -> t

    (** [bot vars] is the empty environment over [vars]. *)
    val bot           : (id * item list) list -> t

    (** [is_bot x] tests whether [x] is the empty environment. *)
    val is_bot        : t -> bool

    (* variable management *)

    (** The variables of an environment, with their lists of items. *)
    val vars          : t -> (id * item list) list

    (** Remove the constraints bearing on one / several variables. *)
    val forgetvar     : t -> id -> t
    val forgetvars    : t -> id list -> t

    (** [project x v] lists the items variable [v] can take in [x]. *)
    val project       : t -> id -> item list

    (* set-theoretic operations *)
    val join          : t -> t -> t   (* union *)
    val meet          : t -> t -> t   (* intersection *)

    val subset        : t -> t -> bool
    val eq            : t -> t -> bool

    (* abstract operations *)

    (** Restrict an environment by one constraint / a conjunction of
        constraints. *)
    val apply_cons    : t -> enum_cons -> t
    val apply_cl      : t -> enum_cons list -> t

    (** [assign x ids]: for each pair [(a, b)] of [ids], forget [a], then
        constrain [a = b]. *)
    val assign        : t -> (id * id) list -> t

    (* pretty-printing *)
    val print         : Format.formatter -> t -> unit
end

    (* Raised inside diagram traversals (in [forgetvars] and [project]) to
       abort early once the result is known to be unconstrained; it is caught
       locally by those functions. *)
    exception Top

    (* The possible values of enumerated variables are represented as
       strings. *)
    type item = string

    (* Enumerated decision diagram.
       - [DBot]: the empty set of valuations;
       - [DTop]: every valuation of the remaining variables;
       - [DChoice (n, x, l)]: branch on variable [x]; [l] associates each item
         of [x]'s domain with the sub-diagram for that choice. [n] is the
         node identifier used as memoization key (see [key]); within one
         [new_node_fun] instance, two nodes get the same identifier iff they
         have the same variable and the same children. *)
    type edd =
        | DBot
        | DTop
        | DChoice of int * id * (item * edd) list

    (* An environment:
       - [vars]: the variables with their domains (lists of items);
       - [order]: the position of each variable in [vars]; this is the
         diagram's variable ordering (variables with a smaller position are
         branched on closer to the root);
       - [root]: the decision diagram itself. *)
    type t = {
      vars      : (id * item list) list;
      order     : (id, int) Hashtbl.t;
      root      : edd;
    }

    (* Utility functions for memoization *)

    (* [key node] is the integer under which [node] is stored in memo
       tables: the identifier of a [DChoice] node, and the negative values
       -1 and -2 for the [DBot] and [DTop] leaves (distinct from node
       identifiers as long as those are non-negative). *)
    let key node =
        match node with
        | DChoice (ident, _, _) -> ident
        | DTop -> -2
        | DBot -> -1
    (* [memo f] builds a memoized recursive function out of [f]: [f] receives
       the memoized function as its first argument (for recursive calls) and
       the node as its second. Results are cached by [key] of the node; every
       evaluation of [memo f] starts with a fresh cache. *)
    let memo f =
        let cache = Hashtbl.create 12 in
        let rec go node =
          let k = key node in
          if Hashtbl.mem cache k then Hashtbl.find cache k
          else begin
            let result = f go node in
            Hashtbl.add cache k result;
            result
          end
        in
        go
    (* [memo2 f] is the two-argument counterpart of [memo]: [f] receives the
       memoized function and two nodes, and results are cached by the pair of
       the nodes' [key]s. Every evaluation of [memo2 f] uses a fresh cache. *)
    let memo2 f =
        let cache = Hashtbl.create 12 in
        let rec go n1 n2 =
          let k = (key n1, key n2) in
          if Hashtbl.mem cache k then Hashtbl.find cache k
          else begin
            let result = f go n1 n2 in
            Hashtbl.add cache k result;
            result
          end
        in
        go
    (* [edd_node_eq (a, b)] compares two nodes by identity: leaves by
       constructor, [DChoice] nodes by identifier only (children are not
       inspected), so it is only meaningful for nodes whose identifiers come
       from the same construction. *)
    let edd_node_eq (a, b) =
        match a, b with
        | DChoice (i, _, _), DChoice (j, _, _) -> i = j
        | DBot, DBot | DTop, DTop -> true
        | _, _ -> false

    (* [new_node_fun ()] returns a fresh node constructor [dq] with its own
       identifier counter and table.
       [dq v l] builds the node branching on variable [v] with children [l]:
       - if all children are the same node (by [edd_node_eq]), the branch is
         redundant and that child is returned instead;
       - otherwise a [DChoice] is returned whose identifier is shared by all
         nodes built by this [dq] with the same variable and the same
         children (compared by [key]), so identifiers can serve as memo keys.
       [l] must be non-empty ([List.hd] raises [Failure] otherwise). *)
    let new_node_fun () =
        let nc = ref 0 in
        let node_memo = Hashtbl.create 12 in
        fun v l ->
          let _, x0 = List.hd l in
          if List.exists (fun (_, x) -> not (edd_node_eq (x, x0))) l
            then begin
              let k = (v, List.map (fun (a, b) -> (a, key b)) l) in
              let n =
                try Hashtbl.find node_memo k
                with Not_found -> (incr nc; Hashtbl.add node_memo k !nc; !nc)
              in
              DChoice(n, v, l)
            end else x0

    (* [rank v node] is the position, in [v]'s variable ordering, of the
       variable [node] branches on. Leaves get 10000000 so that they come
       after every variable (assuming fewer variables than that).
       Raises [Not_found] if the variable is not in [v]'s ordering. *)
    let rank v node =
        match node with
        | DBot | DTop -> 10000000
        | DChoice (_, var, _) -> Hashtbl.find v.order var

    (* print : Format.formatter -> t -> unit *)
    (* [print fmt v] pretty-prints the diagram of [v]: the root first, then
       the definition "n<k>: ..." of every node that was printed by reference
       as "n<k>", each exactly once, all enclosed in "{ ... }".
       At a node where exactly one branch does not lead to ⊥, only that
       branch is printed, as "x = item, <rest>"; its target is printed by
       reference when at least two parent branches point to it, and inline
       otherwise. Other nodes are printed as "x ? i1 → ..., i2 → ...", where
       every [DChoice] target is printed by reference. *)
    let print fmt v =
        (* Nodes whose definition remains to be printed, and the identifiers
           already queued (so each definition is printed once). *)
        let print_nodes = Queue.create () in
        let a = Hashtbl.create 12 in
        let schedule nn node =
          if not (Hashtbl.mem a nn) then begin
            Queue.push node print_nodes;
            Hashtbl.add a nn ()
          end
        in

        (* [node_pc]: for each [DChoice] identifier, the number of parent
           branches pointing to it ([memo] visits each parent node once). *)
        let node_pc = Hashtbl.create 12 in
        let f f_rec = function
          | DChoice(_, _, l) ->
              List.iter
                (fun (_, c) -> match c with
                  | DChoice(n, _, _) ->
                      let count =
                        try Hashtbl.find node_pc n with Not_found -> 0
                      in
                      Hashtbl.replace node_pc n (count + 1)
                  | _ -> ())
                l;
              List.iter (fun (_, c) -> f_rec c) l
          | _ -> ()
        in
        memo f v.root;

        let rec print_n fmt = function
          | DBot -> Format.fprintf fmt "⊥"
          | DTop -> Format.fprintf fmt "⊤"
          | DChoice(_, var, l) ->
            let live =
              List.filter (fun (_, x) -> match x with DBot -> false | _ -> true) l
            in
            match live with
            | [(c, child)] ->
              (* Single satisfiable branch: print it inline unless its
                 target is shared. *)
              let aux fmt = function
                | DChoice(nn, _, _) as i when Hashtbl.find node_pc nn >= 2 ->
                  schedule nn i;
                  Format.fprintf fmt "n%d" nn
                | x -> print_n fmt x
              in
              Format.fprintf fmt "%a = %s,@ %a" Formula_printer.print_id var c aux child
            | _ ->
              Format.fprintf fmt "%a ? " Formula_printer.print_id var;
              let print_u fmt (c, i) =
                Format.fprintf fmt "%s → " c;
                match i with
                | DChoice(nn, _, _) ->
                  schedule nn i;
                  Format.fprintf fmt "n%d" nn
                | _ -> print_n fmt i
              in
              Format.fprintf fmt "@[<h>%a@]" (print_list print_u ", ") l
        in
        Format.fprintf fmt "@[<v 4>{ @[<hov>%a@]" print_n v.root;
        (* Printing a definition may queue further nodes. *)
        while not (Queue.is_empty print_nodes) do
          match Queue.pop print_nodes with
          | DChoice(n, _, _) as x ->
            Format.fprintf fmt "@ n%d: @[<hov>%a@]" n print_n x
          | _ -> assert false
        done;
        Format.fprintf fmt " }@]"

    (* top : (id * item list) list -> t
       bot : (id * item list) list -> t *)
    (* [top vars] is the unconstrained environment (diagram [DTop]) over
       [vars]. The i-th variable of [vars] gets position i in the variable
       ordering. *)
    let top vars =
        let order = Hashtbl.create 12 in
        let rec register i = function
          | [] -> ()
          | (name, _) :: rest ->
            Hashtbl.add order name i;
            register (i + 1) rest
        in
        register 0 vars;
        { vars = vars; order = order; root = DTop }
    (* [bot vars] is the empty environment (diagram [DBot]) over [vars],
       with the same ordering as [top vars]. *)
    let bot vars = { (top vars) with root = DBot }

    (* is_bot : t -> bool *)
    (* [is_bot x] holds iff the diagram of [x] is the [DBot] leaf. *)
    let is_bot x =
        match x.root with
        | DBot -> true
        | DTop | DChoice _ -> false

    (* vars : t -> (id * item list) list *)
    (* [vars x] is the list of variables of [x] with their domains. *)
    let vars { vars = vs; _ } = vs
    (* of_cons : t -> enum_cons -> t
       The diagram of the first t is NOT used as a decision function; only
       its variable ordering and variable domains are. *)
    (* [of_cons v0 (op, vid, r)] is the environment, over [v0]'s variables,
       of the valuations satisfying the constraint [vid op r]. Only [v0]'s
       ordering and domains are used; its diagram is ignored.
       Nodes are built through a fresh [new_node_fun], so the result is
       reduced like those of [join]/[meet]. For instance, an unsatisfiable
       constraint yields [DBot] rather than a node whose branches are all
       [DBot]. Domains are assumed non-empty, as [new_node_fun] requires. *)
    let of_cons v0 (op, vid, r) =
        let dq = new_node_fun () in
        let cmp = match op with | E_EQ -> (=) | E_NE -> (<>) in
        (* Branch on [var]: item [y] of its domain leads to [DTop] when
           [cmp y x] holds and to [DBot] otherwise. *)
        let test var x =
          dq var
            (List.map (fun y -> if cmp y x then (y, DTop) else (y, DBot))
               (List.assoc var v0.vars))
        in
        let root =
          match r with
          | EItem x -> test vid x
          | EIdent vid2 when vid2 = vid ->
            (* [x op x] does not depend on [x]; building nested nodes on the
               same variable would break the ordering invariant. *)
            (match op with E_EQ -> DTop | E_NE -> DBot)
          | EIdent vid2 ->
            (* Branch first on the variable that comes first in the order. *)
            let a, b =
              if Hashtbl.find v0.order vid < Hashtbl.find v0.order vid2
                then vid, vid2
                else vid2, vid
            in
            dq a (List.map (fun x -> (x, test b x)) (List.assoc a v0.vars))
        in
        { v0 with root = root }

    (* join : t -> t -> t
       meet : t -> t -> t *)
    (* [join a b] is the union of [a] and [b].
       Ranks of the nodes of both diagrams are taken from [a]'s ordering, so
       both are assumed to share it. Unless one side is a leaf, the diagram
       is rebuilt through a fresh [new_node_fun]; this collapses redundant
       branches and makes node identifiers usable as memo keys. *)
    let join a b =
        if a.root = DBot then b else
        if b.root = DBot then a else
        if a.root = DTop || b.root = DTop then { a with root = DTop } else begin
          let dq = new_node_fun () in

          let f f_rec na nb =
            match na, nb with
            | DChoice(_, va, la), DChoice(_, vb, lb) when va = vb ->
              (* Same variable: join the children item by item. *)
              let kl = List.map2
                  (fun (ta, ba) (tb, bb) -> assert (ta = tb);
                    (ta, f_rec ba bb))
                  la lb
              in
              dq va kl

            | DTop, _ | _, DTop -> DTop
            | DBot, DBot -> DBot

            (* Different variables (or a node against a leaf): expand the
               side whose variable comes first in the ordering. *)
            | DChoice(_, va, la), _ when rank a na < rank a nb ->
              let kl = List.map (fun (k, ca) -> (k, f_rec ca nb)) la in
              dq va kl
            | _, DChoice(_, vb, lb) when rank a nb < rank a na ->
              let kl = List.map (fun (k, cb) -> (k, f_rec na cb)) lb in
              dq vb kl

            | _ -> assert false
          in
          { a with root = memo2 f a.root b.root }
        end

    (* [meet a b] is the intersection of [a] and [b].
       Ranks of the nodes of both diagrams are taken from [a]'s ordering, so
       both are assumed to share it. Unless one side is a leaf, the diagram
       is rebuilt through a fresh [new_node_fun]; this collapses redundant
       branches and makes node identifiers usable as memo keys. *)
    let meet a b =
        if a.root = DTop then b else
        if b.root = DTop then a else
        if a.root = DBot || b.root = DBot then { a with root = DBot } else begin
          let dq = new_node_fun () in

          let f f_rec na nb =
            match na, nb with
            | DChoice(_, va, la), DChoice(_, vb, lb) when va = vb ->
              (* Same variable: intersect the children item by item. *)
              let kl = List.map2
                  (fun (ta, ba) (tb, bb) -> assert (ta = tb);
                    (ta, f_rec ba bb))
                  la lb
              in
              dq va kl

            | DBot, _ | _, DBot -> DBot
            | DTop, DTop -> DTop

            (* Different variables (or a node against a leaf): expand the
               side whose variable comes first in the ordering. *)
            | DChoice(_, va, la), _ when rank a na < rank a nb ->
              let kl = List.map (fun (k, ca) -> (k, f_rec ca nb)) la in
              dq va kl
            | _, DChoice(_, vb, lb) when rank a nb < rank a na ->
              let kl = List.map (fun (k, cb) -> (k, f_rec na cb)) lb in
              dq vb kl

            | _ -> assert false
          in
          { a with root = memo2 f a.root b.root }
        end

    (* apply_cons : t -> enum_cons -> t
       apply_cl : t -> enum_cons list -> t *)
    (* [apply_cons v x] intersects [v] with the environment of the
       valuations satisfying constraint [x]. *)
    let apply_cons v x =
      let cons_env = of_cons v x in
      meet v cons_env

    (* [apply_cl v ec] intersects [v] with the conjunction of the constraints
       of [ec]. The constraints are first combined by a balanced tree of
       [meet]s (splitting the list into alternating halves), then the result
       is intersected with [v]. *)
    let apply_cl v ec =
        (* [split l] distributes [l] alternately, preserving order: the 2nd,
           4th, ... elements go left; the 1st, 3rd, ... go right. *)
        let rec split = function
          | x1 :: x2 :: rest ->
            let la, lb = split rest in
            (x2 :: la, x1 :: lb)
          | l -> ([], l)
        in
        let rec cl_k = function
          | [] -> { v with root = DTop }
          | [a] -> of_cons v a
          | l ->
            let la, lb = split l in
            meet (cl_k la) (cl_k lb)
        in
        let cons_edd = cl_k ec in
        meet v cons_edd

    (* eq : t -> t -> bool *)
    (* [eq a b] compares the diagrams of [a] and [b] structurally: same
       leaves, and the same variable at each pair of [DChoice] nodes with
       recursively equal children. Node identifiers are not compared; they
       only serve as memoization keys. Presumably this is set equality for
       reduced diagrams over the same ordering. *)
    let eq a b =
        let same same_rec na nb =
          match na, nb with
          | DChoice (_, va, la), DChoice (_, vb, lb) when va = vb ->
            List.for_all2
              (fun (ia, ca) (ib, cb) -> assert (ia = ib); same_rec ca cb)
              la lb
          | DBot, DBot | DTop, DTop -> true
          | _, _ -> false
        in
        memo2 same a.root b.root

    (* subset : t -> t -> bool *)
    (* [subset a b] holds when every valuation of [a] is one of [b]. The two
       diagrams are walked together; when their variables differ, the side
       whose variable comes first in [a]'s ordering is expanded. Results are
       memoized on node-identifier pairs. *)
    let subset a b =
        let rank_of = rank a in
        let incl incl_rec na nb =
          match na, nb with
          | DBot, _ | _, DTop -> true
          | DTop, DBot -> false
          | DChoice (_, va, la), DChoice (_, vb, lb) when va = vb ->
            List.for_all2
              (fun (ia, ca) (ib, cb) -> assert (ia = ib); incl_rec ca cb)
              la lb
          | DChoice (_, _, la), _ when rank_of na < rank_of nb ->
            List.for_all (fun (_, ca) -> incl_rec ca nb) la
          | _, DChoice (_, _, lb) when rank_of na > rank_of nb ->
            List.for_all (fun (_, cb) -> incl_rec na cb) lb
          | _ -> assert false
        in
        memo2 incl a.root b.root

    (* forgetvars : t -> id list -> t *)
    (* [forgetvars v vars] removes the constraints on the variables [vars]:
       nodes branching on them are replaced by the union of their branches,
       and every other node is kept.
       The core function [f] takes a list [l] of nodes of [v]'s diagram and
       builds the diagram of their union with [vars] forgotten; results are
       memoized on the sorted list of the nodes' keys. *)
    let forgetvars v vars =
        let dq = new_node_fun () in

        let memo = Hashtbl.create 12 in
        let rec f l =
            let kl = List.sort compare (List.map key l) in
            try Hashtbl.find memo kl
            with Not_found ->
              let r =
                try
                  (* [DBot]s do not contribute to the union; any [DTop]
                     makes it [DTop]. *)
                  let cn = List.fold_left
                    (fun cn node -> match node with
                      | DBot -> cn
                      | DTop -> raise Top
                      | DChoice (n, x, ch) -> (n, x, ch) :: cn)
                    [] l in
                  (* Nodes on the earliest variable in the ordering first. *)
                  let cn = List.sort
                    (fun (_, x1, _) (_, x2, _) ->
                      compare (Hashtbl.find v.order x1) (Hashtbl.find v.order x2))
                    cn in
                  begin match cn with
                  | [] -> DBot
                  | (_, dv, cl) :: _ ->
                    (* [d]: nodes on [dv]; [nd]: nodes on later variables,
                       which do not constrain [dv]. *)
                    let d, nd = List.partition (fun (_, x, _) -> x = dv) cn in
                    let ch1 = List.map (fun (n, x, ch) -> DChoice (n, x, ch)) nd in
                    if List.mem dv vars then begin
                      (* Do union of all branches branching from nodes on
                         variable dv *)
                      let ch2 = List.flatten
                        (List.map (fun (_, _, ch) -> List.map snd ch) d) in
                      f (ch1 @ ch2)
                    end else begin
                      (* Keep disjunction on variable dv *)
                      let cc =
                        List.map (fun (c, _) ->
                            let ch2 = List.map (fun (_, _, ch) -> List.assoc c ch) d in
                            (c, f (ch1 @ ch2)))
                          cl in
                      dq dv cc
                    end
                  end
                with Top -> DTop
              in
              Hashtbl.add memo kl r;
              r
        in
        { v with root = f [v.root] }

    (* [forgetvar v x] forgets the single variable [x] (via [forgetvars]). *)
    let forgetvar v x =
        let to_forget = [x] in
        forgetvars v to_forget

    (* project : t -> id -> item list *)
    (* [project v x] lists the items variable [x] can take in [v].
       The diagram is explored through the nodes on variables that precede
       [x] in the ordering. At nodes on [x], the items of the non-[DBot]
       branches are collected (without duplicates). If some path reaches
       [DTop], or a node on a later variable, without branching on [x], then
       [x] is unconstrained there and its whole domain is returned. *)
    let project v x =
        let vals = ref [] in
        let f f_rec = function
          | DBot -> ()
          | DChoice(_, var, l) when
              Hashtbl.find v.order var < Hashtbl.find v.order x ->
            List.iter (fun (_, c) -> f_rec c) l
          | DChoice(_, var, l) when var = x ->
            List.iter
              (fun (item, c) ->
                if c <> DBot && not (List.mem item !vals) then
                  vals := item :: !vals)
              l
          | _ -> raise Top
        in
        try memo f v.root; !vals
        with Top -> List.assoc x v.vars

    (* assign : t -> (id * id) list -> t *)
    (* [assign v ids]: forget every target [x] of the pairs [(x, y)] in
       [ids], then add the constraints [x = y]. All targets are forgotten
       before any constraint is added. If a source [y] is itself a target,
       its previous value has already been forgotten (e.g. a swap loses the
       original values). *)
    let assign v ids =
      let v = forgetvars v (List.map fst ids) in
      apply_cl v (List.map (fun (x, y) -> (E_EQ, x, EIdent y)) ids)
