summaryrefslogblamecommitdiff
path: root/frontend/ast_printer.ml
blob: 08f4fb48446b858a476d9fe1f6a1c035e33f2931 (plain) (tree)


























                                                                                                                                      






                                  
                                   





                  
                                    

                    

 










                                           

                              
                                                       




                                                        











                                                            


                              

           
                                














                                                         






                                                  
                          



                                                  








                                                          























                                                            

                                                       
                                                        

                                                  



                                              

                                          
                                       
 
                                

                                                                  
 
 



                         

















                                                                      


                                                         
                                 

                                              
                                    

                                                 
                                                













                                                                        
                                    
                             







                                                                        

                                                           
                             

                                        

                                                                                                                   

              





                                                           
 
                                         
                                                   
              

                                             

                            
 
                                           
                      
                               
                                    
                              






                                                     
open Ast
open Lexing


(* Locations *)

(* Render a lexer position as "file:line:column".  The column is the
   offset of the position from the beginning of its line
   (pos_cnum - pos_bol). *)
let string_of_position p =
  let column = p.pos_cnum - p.pos_bol in
  Printf.sprintf "%s:%i:%i" p.pos_fname p.pos_lnum column
    
(* Render the source range delimited by the positions p and q, repeating
   only the components on which the two ends differ:
     - same file, line and character offset:  "file:line.col"
     - same file and line:                    "file:line.col-col"
     - same file, different lines:            "file:line.col-line.col"
     - different files:                       "file:line.col-file:line.col"
   Columns are computed as pos_cnum - pos_bol. *)
let string_of_extent (p,q) =
  let col pos = pos.pos_cnum - pos.pos_bol in
  let same_file = (p.pos_fname = q.pos_fname)
  and same_line = (p.pos_lnum = q.pos_lnum)
  and same_char = (p.pos_cnum = q.pos_cnum) in
  match same_file, same_line, same_char with
  | false, _, _ ->
    Printf.sprintf "%s:%i.%i-%s:%i.%i"
      p.pos_fname p.pos_lnum (col p) q.pos_fname q.pos_lnum (col q)
  | true, false, _ ->
    Printf.sprintf "%s:%i.%i-%i.%i"
      p.pos_fname p.pos_lnum (col p) q.pos_lnum (col q)
  | true, true, false ->
    Printf.sprintf "%s:%i.%i-%i" p.pos_fname p.pos_lnum (col p) (col q)
  | true, true, true ->
    Printf.sprintf "%s:%i.%i" p.pos_fname p.pos_lnum (col p)


(* Operators *)

(* Concrete syntax of a unary arithmetic operator. *)
let string_of_unary_op op =
  match op with
  | AST_UMINUS -> "-"
  | AST_UPLUS -> "+"

(* Concrete syntax of a binary arithmetic operator. *)
let string_of_binary_op op =
  match op with
  | AST_PLUS -> "+"
  | AST_MINUS -> "-"
  | AST_MUL -> "*"
  | AST_DIV -> "/"
  | AST_MOD -> "mod"
(* Concrete syntax of a comparison operator. *)
let string_of_binary_rel rel =
  match rel with
  | AST_EQ -> "="
  | AST_NE -> "<>"
  | AST_LE -> "<="
  | AST_LT -> "<"
  | AST_GE -> ">="
  | AST_GT -> ">"
(* Concrete syntax of a binary boolean connective. *)
let string_of_binary_bool connective =
  match connective with
  | AST_OR -> "or"
  | AST_AND -> "and"


(* Precedence level of a binary arithmetic operator: the multiplicative
   operators (51) rank above the additive ones (50). *)
let binary_op_precedence op =
  match op with
  | AST_PLUS | AST_MINUS -> 50
  | AST_MUL | AST_DIV | AST_MOD -> 51
(* Precedence level of a comparison operator: equality and disequality (41)
   rank above the ordering comparisons (40). *)
let binary_rel_precedence rel =
  match rel with
  | AST_LT | AST_LE | AST_GT | AST_GE -> 40
  | AST_EQ | AST_NE -> 41
(* Precedence level of a boolean connective: "and" (31) ranks above
   "or" (30). *)
let binary_bool_precedence connective =
  match connective with
  | AST_OR -> 30
  | AST_AND -> 31
(* Precedence level of the arrow operator; it sits below every boolean
   connective (30/31) and above if-then-else. *)
let arrow_precedence = 20
(* Precedence level of if-then-else; the lowest level handed out by
   expr_precedence. *)
let if_precedence = 10

(* Precedence level of an expression, used by the printer to decide where
   parentheses are needed (a higher number binds tighter).  Prefix forms
   (unary operators, pre, not) are at 99; infix forms delegate to the
   operator tables; every remaining constructor is at 100. *)
let expr_precedence expr =
  match expr with
  | AST_if (_, _, _) -> if_precedence
  | AST_arrow (_, _) -> arrow_precedence
  | AST_binary_bool (connective, _, _) -> binary_bool_precedence connective
  | AST_binary_rel (rel, _, _) -> binary_rel_precedence rel
  | AST_binary (op, _, _) -> binary_op_precedence op
  | AST_unary (_, _) | AST_pre (_, _) | AST_not (_) -> 99
  | _ -> 100

(* utility *)

(* [print_list f sep fmt l] prints the elements of l on fmt using f, and
   writes the literal string sep between consecutive elements.  Nothing is
   printed for the empty list and no separator follows the last element.
   Implemented with the standard Format.pp_print_list rather than a
   hand-written recursion. *)
let print_list f sep fmt l =
  let pp_sep out () = Format.pp_print_string out sep in
  Format.pp_print_list ~pp_sep f fmt l

(* Print the name carried by an annotated identifier, i.e. the first
   component of the pair; the second component is ignored. *)
let print_id_ext fmt annotated =
  Format.pp_print_string fmt (fst annotated)

(* types *)

(* Concrete syntax of a base type.  The function never calls itself, so
   it is defined without the rec flag. *)
let string_of_typ = function
  | AST_TINT -> "int"
  | AST_TBOOL -> "bool"
  | AST_TREAL -> "real"

(* expressions *)

(* Print an identifier verbatim; identifiers are plain strings. *)
let print_id = Format.pp_print_string

(* [print_expr fmt e] prints the expression e on fmt in concrete syntax.

   Parentheses are inserted by comparing precedence levels obtained from
   expr_precedence:
     - the operand of a prefix form (unary operator, not, pre) is
       parenthesised unless it binds strictly tighter than the prefix form;
     - the left operand of an infix form is parenthesised only when it binds
       strictly looser than the operator, the right operand also when it has
       the same level, so equal-level chains are printed as left-nested;
     - the three parts of an if-then-else and the arguments of a node
       instance are printed without parentheses of their own.

   The parenthesisation logic, previously repeated for every prefix and
   infix constructor, lives in the local helpers below. *)
let rec print_expr fmt e =
  let level = expr_precedence e in
  (* Print sub, wrapped in parentheses when needs_parens holds. *)
  let print_sub needs_parens sub =
    if needs_parens then Format.fprintf fmt "(%a)" print_expr sub
    else print_expr fmt sub
  in
  (* Prefix form: symbol (including any trailing space) then the operand. *)
  let print_prefix symbol operand =
    Format.pp_print_string fmt symbol;
    print_sub (expr_precedence operand <= level) operand
  in
  (* Infix form: operands separated by the symbol surrounded by spaces. *)
  let print_infix symbol lhs rhs =
    print_sub (expr_precedence lhs < level) lhs;
    Format.fprintf fmt " %s " symbol;
    print_sub (expr_precedence rhs <= level) rhs
  in
  match e with

  | AST_unary (op,(e1,_)) -> print_prefix (string_of_unary_op op) e1
  | AST_not (e1,_) -> print_prefix "not " e1
  | AST_pre ((e1,_), _) -> print_prefix "pre " e1

  | AST_binary (op,(e1,_),(e2,_)) ->
      print_infix (string_of_binary_op op) e1 e2
  | AST_binary_rel (op,(e1,_),(e2,_)) ->
      print_infix (string_of_binary_rel op) e1 e2
  | AST_binary_bool (op,(e1,_),(e2,_)) ->
      print_infix (string_of_binary_bool op) e1 e2
  | AST_arrow ((e1,_),(e2,_)) ->
      print_infix "->" e1 e2

  | AST_int_const (i,_) -> Format.pp_print_string fmt i
  | AST_real_const (r,_) -> Format.pp_print_string fmt r
  | AST_bool_const b -> Format.pp_print_bool fmt b

  | AST_if((c,_), (t,_), (f,_)) ->
      Format.fprintf fmt
        "if %a then %a else %a"
        print_expr c print_expr t print_expr f

  | AST_identifier (v,_) -> print_id fmt v
  | AST_idconst (v,_) -> print_id fmt v

  | AST_instance ((i,_),l, _) ->
      (* Arguments carry an annotation that is dropped before printing. *)
      Format.fprintf fmt "%a(%a)"
        print_id i (print_list print_expr ", ") (List.map fst l)


(* equations *)

(* Extend an indentation prefix by one level (two spaces). *)
let indent ind = ind ^ String.make 2 ' '

(* Print a local variable section on a single line: the indentation prefix
   and the keyword "var", then each declaration preceded by a space and
   followed by a semicolon, then a newline.  Prints nothing at all when the
   declaration list is empty. *)
let rec print_vars ind fmt decls =
  match decls with
  | [] -> ()
  | _ :: _ ->
    Format.pp_print_string fmt ind;
    Format.pp_print_string fmt "var";
    List.iter
      (fun decl ->
        Format.pp_print_string fmt " ";
        print_var_decl fmt decl;
        Format.pp_print_string fmt ";")
      decls;
    Format.fprintf fmt "@\n"

(* Print one variable declaration as "name: type", prefixed by "probe "
   when the boolean flag of the declaration is set.  The annotation attached
   to the name is ignored. *)
and print_var_decl fmt (is_probe, (name, _), ty) =
    if is_probe then Format.pp_print_string fmt "probe ";
    Format.fprintf fmt "%s: %s" name (string_of_typ ty)

(* Print a block of equations between the keywords "let" and "tel", each
   keyword on its own line at indentation ind. *)
and print_body ind fmt body =
  Format.fprintf fmt "%slet@\n" ind;
  print_block ind fmt body;
  Format.fprintf fmt "%stel@\n" ind

(* Print one equation at indentation ind: an assignment, an assume or
   guarantee clause, an automaton or an activate block.  The first three
   forms end with a semicolon and a newline; the last two are delegated to
   their dedicated printers. *)
and print_eqn ind fmt eqn =
  (* assume and guarantee differ only by their keyword. *)
  let print_property keyword name cond =
    Format.fprintf fmt "%s%s %s: %a;@\n" ind keyword name print_expr cond
  in
  match eqn with
  | AST_automaton a -> print_automaton ind fmt a
  | AST_activate a -> print_activate ind fmt a
  | AST_assume((name, _), (cond, _)) -> print_property "assume" name cond
  | AST_guarantee((name, _), (cond, _)) -> print_property "guarantee" name cond
  | AST_assign (lhs,(rhs,_)) ->
      Format.pp_print_string fmt ind;
      print_list print_id_ext ", " fmt lhs;
      Format.fprintf fmt " = %a;@\n" print_expr rhs

(* Print an activate block: the "activate" keyword, the condition tree one
   indentation level deeper, then a "returns" line listing the returned
   identifiers separated by commas. *)
and print_activate ind fmt (tree, returned) =
  let inner = indent ind in
  Format.fprintf fmt "%sactivate@\n" ind;
  print_activate_if inner fmt tree;
  Format.fprintf fmt "%sreturns " ind;
  print_list print_id ", " fmt returned;
  Format.fprintf fmt ";@\n"

(* Print the condition tree of an activate block.  An inner node is printed
   as "if c then" / "else" with both branches one indentation level deeper;
   a leaf prints its local variables followed by its body. *)
and print_activate_if ind fmt tree =
  match tree with
  | AST_activate_body(b) ->
    print_vars ind fmt b.act_locals;
    print_body ind fmt b.body
  | AST_activate_if((cond, _), on_true, on_false) ->
    let inner = indent ind in
    Format.fprintf fmt "%sif %a then@\n" ind print_expr cond;
    print_activate_if inner fmt on_true;
    Format.fprintf fmt "%selse@\n" ind;
    print_activate_if inner fmt on_false

(* Print an automaton: its header line, every state one indentation level
   deeper, then a "returns" line listing the returned identifiers separated
   by commas. *)
and print_automaton ind fmt (name, states, returned) =
  let inner = indent ind in
  Format.fprintf fmt "%sautomaton %s@\n" ind name;
  List.iter (fun state -> print_state inner fmt state) states;
  Format.fprintf fmt "%sreturns " ind;
  print_list print_id ", " fmt returned;
  Format.fprintf fmt ";@\n"

(* Print one automaton state: the header line (with "initial " when the
   state is flagged initial), its local variables and body, and, when the
   state has outgoing transitions, an "until" section with one
   "if cond restart|resume target;" line per transition, one indentation
   level deeper. *)
and print_state ind fmt (st, _) =
  let qualifier = if st.initial then "initial " else "" in
  Format.fprintf fmt "%s%sstate %s@\n" ind qualifier st.st_name;
  print_vars ind fmt st.st_locals;
  print_body ind fmt st.body;
  match st.until with
  | [] -> ()
  | transitions ->
    Format.fprintf fmt "%suntil@\n" ind;
    let inner = indent ind in
    List.iter
      (fun ((cond, _), (target, _), reset) ->
        let keyword = if reset then "restart" else "resume" in
        Format.fprintf fmt "%sif %a %s %s;@\n"
          inner print_expr cond keyword target)
      transitions

(* Print every equation of a block one indentation level deeper than ind;
   the annotation attached to each equation is ignored. *)
and print_block ind fmt b =
  let inner = indent ind in
  List.iter (fun (eqn, _) -> print_eqn inner fmt eqn) b

(* declarations *)


(* Print a node declaration: the header line
   "node name(args) returns(rets)" with declarations separated by "; ",
   then the local variables and the body at top-level indentation. *)
and print_node_decl fmt (d : node_decl) =
    let print_params = print_list print_var_decl "; " in
    Format.fprintf fmt "node %s(" d.n_name;
    print_params fmt d.args;
    Format.pp_print_string fmt ") returns(";
    print_params fmt d.ret;
    Format.fprintf fmt ")@\n";
    print_vars "" fmt d.var;
    print_body "" fmt d.body

(* Print a constant declaration as "const name: type = value" followed by
   a blank line.  Only the expression part of the value is printed; its
   annotation is ignored. *)
let print_const_decl fmt (d : const_decl) =
    let (init, _) = d.value in
    Format.fprintf fmt "const %s: %s = " d.c_name (string_of_typ d.typ);
    print_expr fmt init;
    Format.fprintf fmt "@\n@\n"

(* Print one top-level declaration (a constant or a node), ignoring the
   annotation attached to it. *)
let print_toplevel fmt decl =
    match decl with
    | AST_const_decl (c, _) -> print_const_decl fmt c
    | AST_node_decl (n, _) -> print_node_decl fmt n

(* Print a whole program: every top-level declaration, in order. *)
let print_prog fmt p =
    List.iter (fun decl -> print_toplevel fmt decl) p