summaryrefslogblamecommitdiff
path: root/frontend/ast_printer.ml
blob: 09f4986abfd8ca70a7f29545f40e588edaaa7b2b (plain) (tree)
1
2
3
4
5
6
7
8

           
           
         
 


                                             






                                 






                                  
                                   





                  
                                    

                    

 
                         






                                           

                 

                         

                              
                                                                     




                                                        



             


                              

           
                                














                                                         






                                                  
                          



                                                  


                                           








                                                          








                                                           
                                                













                                                            

                                                       
                                                        

                                                  



                                              

                                          
                                       
 
                                

                                                                  
 



                                                     
 



                         




                                      


                                                                      
                        
 
                                    








                                         


                                                           
                                


                                                         
                                 

                                              
                                    

                                                 
                                                













                                                                        
                                              







                                                                        
                                                           

                                              

                                        

                                                                                                                   

              
 

                  
 
                                         
                                                   
              

                                             
                                      
 
                                           
                      
                               
                                    
                              






                                                     






                                        
                                                                                     


                                                                              
open Ast
open Lexing
open Typing
open Util

(*
  Just a pretty-printer, nothing to see here.
*)


(* Operators *)

(* Concrete syntax of a prefix arithmetic operator. *)
let string_of_unary_op op =
  match op with
  | AST_UMINUS -> "-"
  | AST_UPLUS -> "+"

(* Concrete syntax of an arithmetic infix operator. *)
let string_of_binary_op op =
  match op with
  | AST_PLUS -> "+"
  | AST_MINUS -> "-"
  | AST_MUL -> "*"
  | AST_DIV -> "/"
  | AST_MOD -> "mod"
(* Concrete syntax of a comparison operator. *)
let string_of_binary_rel rel =
  match rel with
  | AST_LT -> "<"
  | AST_LE -> "<="
  | AST_GT -> ">"
  | AST_GE -> ">="
  | AST_EQ -> "="
  | AST_NE -> "<>"
(* Concrete syntax of a boolean connective. *)
let string_of_binary_bool conn =
  match conn with
  | AST_OR -> "or"
  | AST_AND -> "and"


(* Binding strength of each operator class: a larger number binds tighter.
   [expr_precedence] maps an expression to one of these values and the
   printer compares parent and child values to decide where parentheses
   go.
   NOTE(review): in this table [or] (31) binds tighter than [and] (30), and
   [=]/[<>] (41) bind tighter than the ordering comparisons (40). Both are
   unusual; confirm they match the parser's precedence declarations. *)
let unary_precedence = 99

let binary_op_precedence op =
  match op with
  | AST_PLUS | AST_MINUS -> 50
  | AST_MUL | AST_DIV | AST_MOD -> 51

let binary_rel_precedence rel =
  match rel with
  | AST_LT | AST_LE | AST_GT | AST_GE -> 40
  | AST_EQ | AST_NE -> 41

let binary_bool_precedence conn =
  match conn with
  | AST_AND -> 30
  | AST_OR -> 31

let arrow_precedence = 20

let if_precedence = 10

(* [expr_precedence e] is the binding strength of the outermost construct of
   [e]. Every other constructor (e.g. constants, identifiers, casts,
   instances, tuples) falls into the catch-all and gets 100. That is above
   every operator level, so such expressions are never parenthesised. *)
let expr_precedence e =
  match e with
  | AST_if _ -> if_precedence
  | AST_arrow _ -> arrow_precedence
  | AST_binary_bool (conn, _, _) -> binary_bool_precedence conn
  | AST_binary_rel (rel, _, _) -> binary_rel_precedence rel
  | AST_binary (op, _, _) -> binary_op_precedence op
  | AST_unary _ | AST_pre _ | AST_not _ -> unary_precedence
  | _ -> 100

(* utility *)

(* [print_id_ext fmt (name, _)] prints the first component [name] of a pair
   verbatim and ignores the second component. *)
let print_id_ext fmt (name, _) = Format.fprintf fmt "%s" name

(* types *)

(* [string_of_typ t] is the concrete-syntax keyword for the source-level type
   [t]. The function does not call itself, so it is bound with a plain [let];
   a [rec] flag here would only trigger warning 39 (unused rec flag). *)
let string_of_typ = function
  | AST_TINT -> "int"
  | AST_TBOOL -> "bool"
  | AST_TREAL -> "real"

(* expressions *)

(* [print_id fmt v] prints the identifier string [v] verbatim. *)
let print_id fmt v = Format.fprintf fmt "%s" v

(* [print_expr fmt e] pretty-prints the expression [e] on [fmt].

   Parentheses around a sub-expression are decided by comparing its
   [expr_precedence] with that of [e]:
   - operands of the prefix forms (unary [+]/[-], [not], [pre]) and right
     operands of every infix operator are wrapped when they bind at most as
     tightly as [e];
   - left operands of arithmetic, relational and [->] operators are wrapped
     only when they bind strictly more loosely than [e];
   - left operands of [and]/[or] are wrapped when they bind at most as
     tightly as [e].
   A cast always puts its operand in parentheses ("ty (e)"). The parts of
   an [if] are printed without extra parentheses. Instance arguments and
   tuple components are printed with [print_list] using ", " as the
   separator. *)
let rec print_expr fmt e =
  let prec = expr_precedence e in
  (* [paren wrap fmt x] prints [x], surrounded by parentheses iff [wrap]. *)
  let paren wrap fmt x =
    if wrap then Format.fprintf fmt "(%a)" print_expr x
    else print_expr fmt x
  in
  (* Wrap [x] when it binds at most as tightly as [e]. *)
  let weak fmt x = paren (expr_precedence x <= prec) fmt x in
  (* Wrap [x] only when it binds strictly more loosely than [e]. *)
  let strict fmt x = paren (expr_precedence x < prec) fmt x in
  let prefix kw operand = Format.fprintf fmt "%s%a" kw weak operand in
  let infix print_left left op right =
    Format.fprintf fmt "%a %s %a" print_left left op weak right
  in
  match e with
  | AST_unary (op, (e1, _)) -> prefix (string_of_unary_op op) e1
  | AST_not (e1, _) -> prefix "not " e1
  | AST_pre ((e1, _), _) -> prefix "pre " e1
  | AST_cast ((e1, _), ty) ->
      Format.fprintf fmt "%s (%a)" (string_of_typ ty) print_expr e1
  | AST_binary (op, (e1, _), (e2, _)) ->
      infix strict e1 (string_of_binary_op op) e2
  | AST_binary_rel (op, (e1, _), (e2, _)) ->
      infix strict e1 (string_of_binary_rel op) e2
  | AST_binary_bool (op, (e1, _), (e2, _)) ->
      infix weak e1 (string_of_binary_bool op) e2
  | AST_arrow ((e1, _), (e2, _)) -> infix strict e1 "->" e2
  | AST_int_const (lit, _) | AST_real_const (lit, _) ->
      Format.pp_print_string fmt lit
  | AST_bool_const b -> Format.pp_print_bool fmt b
  | AST_if ((c, _), (t, _), (f, _)) ->
      Format.fprintf fmt "if %a then %a else %a"
        print_expr c print_expr t print_expr f
  | AST_identifier (v, _) | AST_idconst (v, _) -> print_id fmt v
  | AST_instance ((i, _), args, _) ->
      Format.fprintf fmt "%a(%a)"
        print_id i (print_list print_expr ", ") (List.map fst args)
  | AST_tuple items ->
      Format.fprintf fmt "(%a)"
        (print_list print_expr ", ") (List.map fst items)


(* equations *)

(* [indent ind] is the indentation one level deeper than [ind]: [ind]
   followed by two spaces. *)
let indent ind = ind ^ String.make 2 ' '

(* [print_scope ind fmt (locals, eqns)] prints a scope at indentation [ind].
   A single equation with no local declarations is printed bare. Otherwise
   the equations are printed inside a [let ... tel] body, preceded by a
   [var d1; d2; ...] line when there are local declarations. *)
let rec print_scope ind fmt (locals, eqns) =
  match locals, eqns with
  | [], [ (eqn, _) ] -> print_eqn ind fmt eqn
  | [], _ -> print_body ind fmt eqns
  | _ :: _, _ ->
      Format.fprintf fmt "%svar" ind;
      List.iter (fun decl -> Format.fprintf fmt " %a;" print_var_decl decl) locals;
      Format.fprintf fmt "@\n";
      print_body ind fmt eqns

(* [print_var_decl fmt (probe, name, ty)] prints one declaration as
   [name: ty], prefixed with "probe " when the flag is set. *)
and print_var_decl fmt (probe, name, ty) =
  let probe_kw = if probe then "probe " else "" in
  Format.fprintf fmt "%s%s: %s" probe_kw name (string_of_typ ty)

(* [print_body ind fmt body] prints [let], the equations one level deeper,
   then [tel], with [let]/[tel] at indentation [ind]. *)
and print_body ind fmt body =
  Format.fprintf fmt "%slet@\n" ind;
  print_block ind fmt body;
  Format.fprintf fmt "%stel@\n" ind

(* [print_block ind fmt b] prints every equation of [b] (ignoring the second
   component of each pair) at one level deeper than [ind]. *)
and print_block ind fmt b =
  let inner = indent ind in
  List.iter (fun (eqn, _) -> print_eqn inner fmt eqn) b

(* [print_eqn ind fmt eqn] prints one equation at indentation [ind]:
   - an assignment [x, y = e;];
   - an [assume]/[guarantee] clause [kw name: e;];
   - a nested automaton or activate block. *)
and print_eqn ind fmt eqn =
  (* Shared layout of the named assume/guarantee clauses. *)
  let print_clause kw name cond =
    Format.fprintf fmt "%s%s %s: %a;@\n" ind kw name print_expr cond
  in
  match eqn with
  | AST_assign (lhs, (rhs, _)) ->
      Format.fprintf fmt "%s%a = %a;@\n"
        ind (print_list print_id_ext ", ") lhs print_expr rhs
  | AST_assume ((name, _), (cond, _)) -> print_clause "assume" name cond
  | AST_guarantee ((name, _), (cond, _)) -> print_clause "guarantee" name cond
  | AST_automaton aut -> print_automaton ind fmt aut
  | AST_activate act -> print_activate ind fmt act

(* [print_activate ind fmt (tree, rets)] prints an [activate] block: the
   condition tree one level deeper, then [returns r1, r2;]. *)
and print_activate ind fmt (tree, rets) =
  Format.fprintf fmt "%sactivate@\n" ind;
  print_activate_if (indent ind) fmt tree;
  Format.fprintf fmt "%sreturns %a;@\n" ind (print_list print_id ", ") rets

(* [print_activate_if ind fmt node] prints one node of an activate condition
   tree. A leaf is printed as a scope. A branch is printed as
   [if c then ... else ...], with both sub-trees one level deeper. *)
and print_activate_if ind fmt = function
  | AST_activate_body blk -> print_scope ind fmt (blk.act_locals, blk.body)
  | AST_activate_if ((cond, _), then_tree, else_tree) ->
      let inner = indent ind in
      Format.fprintf fmt "%sif %a then@\n" ind print_expr cond;
      print_activate_if inner fmt then_tree;
      Format.fprintf fmt "%selse@\n" ind;
      print_activate_if inner fmt else_tree

(* [print_automaton ind fmt (name, states, rets)] prints the automaton header,
   each state one level deeper, then [returns r1, r2;]. *)
and print_automaton ind fmt (name, states, rets) =
  Format.fprintf fmt "%sautomaton %s@\n" ind name;
  let state_ind = indent ind in
  List.iter (fun st -> print_state state_ind fmt st) states;
  Format.fprintf fmt "%sreturns %a;@\n" ind (print_list print_id ", ") rets

(* [print_state ind fmt (st, _)] prints [initial ]state NAME, then the
   state's scope one level deeper. When the state has transitions it then
   prints an [until] section (at the scope's level) with one
   [if COND restart|resume TARGET;] line per transition, one level deeper
   still. *)
and print_state ind fmt (st, _) =
  let initial_kw = if st.initial then "initial " else "" in
  Format.fprintf fmt "%s%sstate %s@\n" ind initial_kw st.st_name;
  let body_ind = indent ind in
  print_scope body_ind fmt (st.st_locals, st.body);
  match st.until with
  | [] -> ()
  | transitions ->
      let trans_ind = indent body_ind in
      Format.fprintf fmt "%suntil@\n" body_ind;
      List.iter
        (fun ((cond, _), (target, _), reset) ->
          let mode = if reset then "restart" else "resume" in
          Format.fprintf fmt "%sif %a %s %s;@\n"
            trans_ind print_expr cond mode target)
        transitions

(* declarations *)


(* [print_node_decl fmt d] prints the header [node NAME(args) returns(rets)],
   with declarations separated by "; ", then the node's scope at top-level
   (empty) indentation. It takes part in no recursion, so it is bound on its
   own rather than inside the printer's [let rec] group. *)
let print_node_decl fmt (d : node_decl) =
  let decls fmt l = print_list print_var_decl "; " fmt l in
  Format.fprintf fmt "node %s(%a) returns(%a)@\n"
    d.n_name decls d.args decls d.ret;
  print_scope "" fmt (d.var, d.body)

(* [print_const_decl fmt d] prints [const NAME: TYPE = VALUE], followed by
   two forced newlines. *)
let print_const_decl fmt (d : const_decl) =
  let value, _ = d.value in
  let ty = string_of_typ d.typ in
  Format.fprintf fmt "const %s: %s = %a@\n@\n" d.c_name ty print_expr value

(* [print_toplevel fmt decl] prints one top-level declaration by
   dispatching on its kind. *)
let print_toplevel fmt decl =
  match decl with
  | AST_const_decl (c, _) -> print_const_decl fmt c
  | AST_node_decl (n, _) -> print_node_decl fmt n

(* [print_prog fmt p] prints every top-level declaration of [p] in order.
   Separation between declarations comes from the individual printers. *)
let print_prog fmt p =
  p |> List.iter (fun decl -> print_toplevel fmt decl)


(* Typed variable *)

(* [print_type fmt ty] prints a typing-level type: [int], [real], or
   [enum { c1, c2 }]. The enum constructor list is laid out inside a
   horizontal ([<h>]) box. *)
let print_type fmt ty =
  match ty with
  | TEnum cases ->
      Format.fprintf fmt "enum @[<h>{ %a }@]" (print_list print_id ", ") cases
  | TInt -> Format.pp_print_string fmt "int"
  | TReal -> Format.pp_print_string fmt "real"

(* [print_typed_var fmt (probe, id, ty)] prints [id: ty], prefixed with
   "probe " when the flag is set. *)
let print_typed_var fmt (probe, id, ty) =
  if probe then Format.pp_print_string fmt "probe ";
  Format.fprintf fmt "%s: %a" id print_type ty