open Ast
open Static
open Static_utils
open Printer
open Errors
open Misc
open Mapfold

(** Raised by [unify] / [occur_check] when two types cannot be made equal. *)
exception Unify

type error_kind =
  | Args_arity_error of int * int
  | Params_arity_error of int * int
  | Result_arity_error of int * int
  | Type_error of ty * ty
  | Static_type_error of static_ty * static_ty
  | Static_constraint_false of static_exp

exception Typing_error of error_kind

let error k = raise (Typing_error k)

(** Prints a diagnostic for [err] at location [loc] on stderr. *)
let message loc err =
  match err with
  | Args_arity_error (found, expected) ->
      Format.eprintf "%aWrong number of arguments (found '%d'; expected '%d')@."
        Location.print_location loc found expected
  | Result_arity_error (found, expected) ->
      Format.eprintf "%aWrong number of outputs (found '%d'; expected '%d')@."
        Location.print_location loc found expected
  | Params_arity_error (found, expected) ->
      Format.eprintf "%aWrong number of static parameters (found '%d'; expected '%d')@."
        Location.print_location loc found expected
  | Type_error (found_ty, exp_ty) ->
      Format.eprintf "%aThis expression has type '%a' but '%a' was expected@."
        Location.print_location loc print_type found_ty print_type exp_ty
  | Static_type_error (found_ty, exp_ty) ->
      Format.eprintf "%aThis static expression has type '%a' but '%a' was expected@."
        Location.print_location loc print_static_type found_ty
        print_static_type exp_ty
  | Static_constraint_false se ->
      Format.eprintf "%aThe following constraint is not satisfied: %a@."
        Location.print_location loc print_static_exp se

(** Signature of a node: input/output types, static parameter names and the
    static constraints that must hold on those parameters. *)
type signature =
  { s_inputs : ty list;
    s_outputs : ty list;
    s_params : name list;
    s_constraints : static_exp list }

(** Global table of node signatures, pre-populated with the primitives. *)
module Modules = struct
  let env = ref Ast.NameEnv.empty

  (** Registers the signature of node [n] in [env]. *)
  let add_sig ?(params = []) ?(constr = []) n inp outp =
    let s =
      { s_inputs = inp; s_outputs = outp; s_params = params;
        s_constraints = constr }
    in
    env := Ast.NameEnv.add n s !env

  (* Built-in primitives. *)
  let () =
    add_sig "and" [TBit; TBit] [TBit];
    add_sig "xor" [TBit; TBit] [TBit];
    add_sig "or" [TBit; TBit] [TBit];
    add_sig "not" [TBit] [TBit];
    add_sig "reg" [TBit] [TBit];
    add_sig "mux" [TBit; TBit; TBit] [TBit];
    add_sig ~params:["n"] "print" [TBitArray (mk_static_var "n"); TBit] [TBit];
    add_sig ~params:["n"] "input" [TBit] [TBitArray (mk_static_var "n")];
    (* select : 0 <= i < n *)
    let constr1 =
      mk_static_exp (SBinOp (SLess, mk_static_var "i", mk_static_var "n")) in
    let constr2 =
      mk_static_exp (SBinOp (SLeq, mk_static_int 0, mk_static_var "i")) in
    add_sig ~params:["i"; "n"] ~constr:[constr1; constr2] "select"
      [TBitArray (mk_static_var "n")] [TBit];
    (* concat : n3 = n1 + n2 *)
    let add =
      mk_static_exp (SBinOp (SAdd, mk_static_var "n1", mk_static_var "n2")) in
    add_sig ~params:["n1"; "n2"; "n3"]
      ~constr:[mk_static_exp (SBinOp (SEqual, mk_static_var "n3", add))]
      "concat"
      [TBitArray (mk_static_var "n1"); TBitArray (mk_static_var "n2")]
      [TBitArray (mk_static_var "n3")];
    (* slice : size = min <= max ? max - min + 1 : 0 *)
    let size =
      mk_static_exp
        (SBinOp (SAdd,
                 mk_static_exp
                   (SBinOp (SMinus, mk_static_var "max", mk_static_var "min")),
                 mk_static_int 1))
    in
    let size =
      mk_static_exp
        (SIf (mk_static_exp
                (SBinOp (SLeq, mk_static_var "min", mk_static_var "max")),
              size, mk_static_int 0))
    in
    let constr1 =
      mk_static_exp (SBinOp (SLeq, mk_static_int 0, mk_static_var "min")) in
    let constr2 =
      mk_static_exp (SBinOp (SLess, mk_static_var "max", mk_static_var "n")) in
    add_sig ~params:["min"; "max"; "n"] ~constr:[constr1; constr2] "slice"
      [TBitArray (mk_static_var "n")] [TBitArray size]

  let tys_of_vds vds = List.map (fun vd -> vd.v_ty) vds

  (** Registers (or re-registers) the signature of user node [n] with the
      constraints [constr]. *)
  let add_node n constr =
    let s =
      { s_inputs = tys_of_vds n.n_inputs;
        s_outputs = tys_of_vds n.n_outputs;
        s_params = List.map (fun p -> p.p_name) n.n_params;
        s_constraints = constr }
    in
    env := Ast.NameEnv.add n.n_name s !env

  (** Maps each parameter name to its actual static expression.
      The caller guarantees both lists have the same length. *)
  let build_param_env param_names params =
    List.fold_left2 (fun env pn p -> NameEnv.add pn p env)
      NameEnv.empty param_names params

  let subst_ty env ty =
    match ty with
    | TBitArray se -> TBitArray (subst env se)
    | _ -> ty

  (** Returns the signature of node [n] instantiated with [params].
      Raises [Typing_error] on a parameter-count mismatch; prints a message
      and raises [Error] when [n] is unknown. *)
  let find_node n params =
    try
      let s = Ast.NameEnv.find n !env in
      if List.length s.s_params <> List.length params then
        error (Params_arity_error (List.length params, List.length s.s_params));
      let env = build_param_env s.s_params params in
      let s =
        { s with
          s_inputs = List.map (subst_ty env) s.s_inputs;
          s_outputs = List.map (subst_ty env) s.s_outputs;
          s_constraints = List.map (subst env) s.s_constraints }
      in
      s
    with Not_found ->
      Format.eprintf "Unbound node '%s'@." n;
      raise Error
end

(* Global accumulator of static constraints produced while typing a node. *)
let constr_list = ref []

let add_constraint se = constr_list := se :: !constr_list

let set_constraints cl = constr_list := cl

(** Returns the accumulated constraints and empties the accumulator. *)
let get_constraints () =
  let v = !constr_list in
  constr_list := [];
  v

let fresh_static_var () = SVar ("s_" ^ Misc.gen_symbol ())

(* Functions on types *)

(** Returns a new, unconstrained type variable. *)
let fresh_type =
  let index = ref 0 in
  let gen_index () = (incr index; !index) in
  let fresh_type () = TVar (ref (TIndex (gen_index ()))) in
  fresh_type

(** Returns the canonical (short) representative of [ty] and updates the
    links along the way to point directly at it. *)
let rec ty_repr ty =
  match ty with
  | TVar link ->
      (match !link with
       | TLink ty ->
           let ty = ty_repr ty in
           link := TLink ty;
           ty
       | _ -> ty)
  | _ -> ty

(** Verifies that the variable [index] does not occur in [ty];
    raises [Unify] otherwise. *)
let rec occur_check index ty =
  let ty = ty_repr ty in
  match ty with
  | TUnit | TBit | TBitArray _ -> ()
  | TVar { contents = TIndex n } when index <> n -> ()
  | TProd ty_list -> List.iter (occur_check index) ty_list
  | _ -> raise Unify

(** Unifies [ty1] and [ty2]. Bit-array size equalities are not decided here:
    they are recorded as static constraints via [add_constraint].
    Raises [Unify] on mismatch, or [Typing_error] on a tuple-arity mismatch. *)
let rec unify ty1 ty2 =
  let ty1 = ty_repr ty1 in
  let ty2 = ty_repr ty2 in
  if ty1 == ty2 then ()
  else
    match ty1, ty2 with
    | TBitArray n, TBit | TBit, TBitArray n ->
        add_constraint (mk_static_exp (SBinOp (SEqual, n, mk_static_int 1)))
    | TBitArray n1, TBitArray n2 ->
        add_constraint (mk_static_exp (SBinOp (SEqual, n1, n2)))
    | TVar { contents = TIndex n1 }, TVar { contents = TIndex n2 }
      when n1 = n2 -> ()
    | TProd ty_list1, TProd ty_list2 ->
        if List.length ty_list1 <> List.length ty_list2 then
          error (Result_arity_error (List.length ty_list1,
                                     List.length ty_list2));
        List.iter2 unify ty_list1 ty_list2
    | TVar ({ contents = TIndex n } as link), ty
    | ty, TVar ({ contents = TIndex n } as link) ->
        occur_check n ty;
        link := TLink ty
    | _ -> raise Unify

let prod ty_list =
  match ty_list with
  | [ty] -> ty
  | _ -> TProd ty_list

(* Typing of static exps *)

(** Returns the static type of [se]; raises [Typing_error] when ill-typed. *)
let rec type_static_exp se =
  match se.se_desc with
  | SInt _ | SVar _ -> STInt
  | SBool _ -> STBool
  | SBinOp ((SAdd | SMinus | SMult | SDiv | SPower), se1, se2) ->
      expect_static_exp se1 STInt;
      expect_static_exp se2 STInt;
      STInt
  | SBinOp ((SEqual | SLess | SLeq | SGreater | SGeq), se1, se2) ->
      expect_static_exp se1 STInt;
      expect_static_exp se2 STInt;
      STBool
  | SIf (c, se1, se2) ->
      (* The condition must be boolean; both branches share one type. *)
      expect_static_exp c STBool;
      let ty1 = type_static_exp se1 in
      expect_static_exp se2 ty1;
      ty1

and expect_static_exp se ty =
  let found_ty = type_static_exp se in
  if found_ty <> ty then error (Static_type_error (found_ty, ty))

(** Drops constraints that simplify to [true]; raises [Typing_error] on one
    that simplifies to [false]. Kept constraints are returned unsimplified. *)
let rec simplify_constr cl =
  match cl with
  | [] -> []
  | c :: cl ->
      let c' = simplify NameEnv.empty c in
      (match c'.se_desc with
       | SBool true -> simplify_constr cl
       | SBool false -> error (Static_constraint_false c)
       | _ -> c :: simplify_constr cl)

(** Extracts an equation [var = exp] from constraint [c], if any.
    NOTE(review): an equality nested in an [SIf] whose other branch is [true]
    is extracted without its guard; confirm this is intended. *)
let rec find_simplification_one c =
  match c.se_desc with
  | SBinOp (SEqual, { se_desc = SVar s; _ }, se)
  | SBinOp (SEqual, se, { se_desc = SVar s; _ }) -> Some (s, se)
  | SIf (_, se1, { se_desc = SBool true; _ })
  | SIf (_, { se_desc = SBool true; _ }, se1) -> find_simplification_one se1
  | _ -> None

(** Finds the first equation in [cl] whose variable is not one of [params];
    returns it together with [cl] minus that constraint. *)
let rec find_simplification params cl =
  match cl with
  | [] -> None, []
  | c :: cl ->
      (match find_simplification_one c with
       | Some (s, se) when not (List.mem s params) -> Some (s, se), cl
       | _ ->
           let res, cl = find_simplification params cl in
           res, c :: cl)

(** Solves constraints [cl] by repeatedly eliminating equations on non-parameter
    variables. Returns the remaining constraints and the substitution built.
    Raises [Typing_error] if a constraint becomes [false]. *)
let solve_constr params cl =
  let params = List.map (fun p -> p.p_name) params in
  let subst_and_error env c =
    let c' = subst env c in
    match c'.se_desc with
    | SBool false -> error (Static_constraint_false c)
    | _ -> c'
  in
  let env = ref NameEnv.empty in
  let rec solve_one cl =
    let res, cl = find_simplification params cl in
    match res with
    | None -> cl
    | Some (s, se) ->
        env := NameEnv.add s se !env;
        let cl = List.map (subst_and_error !env) cl in
        solve_one cl
  in
  let cl = simplify_constr cl in
  let cl = solve_one cl in
  cl, !env

(* Typing of expressions *)

(** Types expression [e] in [env]; returns the annotated expression and its
    type. Typing errors are reported at [e.e_loc] and turned into [Error]. *)
let rec type_exp env e =
  try
    let desc, ty =
      match e.e_desc with
      | Econst (VBit _) -> e.e_desc, TBit
      | Econst (VBitArray a) ->
          e.e_desc, TBitArray (mk_static_int (Array.length a))
      | Evar id -> Evar id, IdentEnv.find id env
      | Ereg e ->
          let e = expect_exp env e TBit in
          Ereg e, TBit
      | Emem (MRom, addr_size, word_size, file, args) ->
          (* addr_size > 0 *)
          add_constraint
            (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size)));
          let read_addr = assert_1 args in
          let read_addr = expect_exp env read_addr (TBitArray addr_size) in
          Emem (MRom, addr_size, word_size, file, [read_addr]),
          TBitArray word_size
      | Emem (MRam, addr_size, word_size, file, args) ->
          (* addr_size > 0 *)
          add_constraint
            (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size)));
          let read_addr, write_en, write_addr, data_in = assert_4 args in
          let read_addr = expect_exp env read_addr (TBitArray addr_size) in
          let write_addr = expect_exp env write_addr (TBitArray addr_size) in
          let data_in = expect_exp env data_in (TBitArray word_size) in
          let write_en = expect_exp env write_en TBit in
          let args = [read_addr; write_en; write_addr; data_in] in
          Emem (MRam, addr_size, word_size, file, args), TBitArray word_size
      | Ecall (f, params, args) ->
          let s = Modules.find_node f params in
          (* check arity *)
          if List.length s.s_inputs <> List.length args then
            error (Args_arity_error (List.length args,
                                     List.length s.s_inputs));
          (* check types of all arguments *)
          let args = List.map2 (expect_exp env) args s.s_inputs in
          List.iter add_constraint s.s_constraints;
          Ecall (f, params, args), prod s.s_outputs
    in
    { e with e_desc = desc; e_ty = ty }, ty
  with Typing_error k ->
    message e.e_loc k;
    raise Error

(** Types [e] and unifies its type with [ty]. A mismatch is reported at the
    location of [e] itself (not of the enclosing expression or node). *)
and expect_exp env e ty =
  let e, found_ty = type_exp env e in
  (try unify ty found_ty with
   | Unify ->
       message e.e_loc (Type_error (found_ty, ty));
       raise Error
   | Typing_error k ->
       message e.e_loc k;
       raise Error);
  e

let type_pat env pat =
  match pat with
  | Evarpat x -> IdentEnv.find x env
  | Etuplepat id_list ->
      prod (List.map (fun x -> IdentEnv.find x env) id_list)

let type_eq env (pat, e) =
  let pat_ty = type_pat env pat in
  let e = expect_exp env e pat_ty in
  (pat, e)

(** Extends [env] with the declared type of each variable in [vds]. *)
let build env vds =
  let build_one env vd = IdentEnv.add vd.v_ident vd.v_ty env in
  List.fold_left build_one env vds

(** Types block [b]. Local variables get fresh types. Constraints produced
    inside a static [BIf] branch are guarded by the branch condition. *)
let rec type_block env b =
  match b with
  | BEqs (eqs, vds) ->
      let vds = List.map (fun vd -> { vd with v_ty = fresh_type () }) vds in
      let env = build env vds in
      let eqs = List.map (type_eq env) eqs in
      BEqs (eqs, vds)
  | BIf (se, trueb, falseb) ->
      expect_static_exp se STBool;
      let prev_constr = get_constraints () in
      let trueb = type_block env trueb in
      let true_constr =
        List.map (fun c -> mk_static_exp (SIf (se, c, mk_static_bool true)))
          (get_constraints ())
      in
      let falseb = type_block env falseb in
      let false_constr =
        List.map (fun c -> mk_static_exp (SIf (se, mk_static_bool true, c)))
          (get_constraints ())
      in
      set_constraints (prev_constr @ true_constr @ false_constr);
      BIf (se, trueb, falseb)

(** Replaces type variables in [b] by their representatives and simplifies
    every static expression under the substitution [env]. *)
let ty_repr_block env b =
  let static_exp _funs acc se = simplify env se, acc in
  let ty funs acc ty =
    let ty = ty_repr ty in
    (* go through types to substitute static exps *)
    Mapfold.ty funs acc ty
  in
  let funs = { Mapfold.defaults with ty = ty; static_exp = static_exp } in
  let b, _ = Mapfold.block_it funs () b in
  b

(** Types node [n], solves its static constraints and registers its final
    signature. Errors are reported at [n.n_loc] and turned into [Error]. *)
let node n =
  try
    (* Start from a clean accumulator so constraints left over by a previously
       aborted node cannot leak into this one. *)
    set_constraints [];
    (* Register a provisional signature so recursive calls can be typed. *)
    Modules.add_node n [];
    let env = build IdentEnv.empty n.n_inputs in
    let env = build env n.n_outputs in
    let body = type_block env n.n_body in
    let constr = get_constraints () in
    let constr, env = solve_constr n.n_params constr in
    let body = ty_repr_block env body in
    Modules.add_node n constr;
    { n with n_body = body; n_constraints = constr }
  with Typing_error k ->
    message n.n_loc k;
    raise Error

let program p =
  let p_nodes = List.map node p.p_nodes in
  { p with p_nodes = p_nodes }