diff options
Diffstat (limited to 'minijazz/src/analysis/typing.ml')
-rw-r--r-- | minijazz/src/analysis/typing.ml | 375 |
1 files changed, 375 insertions, 0 deletions
diff --git a/minijazz/src/analysis/typing.ml b/minijazz/src/analysis/typing.ml new file mode 100644 index 0000000..0212b95 --- /dev/null +++ b/minijazz/src/analysis/typing.ml @@ -0,0 +1,375 @@ +open Ast +open Static +open Static_utils +open Printer +open Errors +open Misc +open Mapfold + +exception Unify + +type error_kind = + | Args_arity_error of int * int + | Params_arity_error of int * int + | Result_arity_error of int * int + | Type_error of ty * ty + | Static_type_error of static_ty * static_ty + | Static_constraint_false of static_exp + +exception Typing_error of error_kind + +let error k = raise (Typing_error k) + +let message loc err = match err with + | Args_arity_error (found, expected) -> + Format.eprintf "%aWrong number of arguments (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Result_arity_error (found, expected) -> + Format.eprintf "%aWrong number of outputs (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Params_arity_error (found, expected) -> + Format.eprintf "%aWrong number of static parameters (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Type_error (found_ty, exp_ty) -> + Format.eprintf "%aThis expression has type '%a' but '%a' was expected@." + Location.print_location loc print_type found_ty print_type exp_ty + | Static_type_error (found_ty, exp_ty) -> + Format.eprintf "%aThis static expression has type '%a' but '%a' was expected@." + Location.print_location loc print_static_type found_ty print_static_type exp_ty + | Static_constraint_false se -> + Format.eprintf "%aThe following constraint is not satisfied: %a@." + Location.print_location loc print_static_exp se + + +type signature = + { s_inputs : ty list; + s_outputs : ty list; + s_params : name list; + s_constraints : static_exp list } + +module Modules = struct + let env = ref Ast.NameEnv.empty + + let add_sig ?(params = []) ?(constr = []) n inp outp = + let s = { s_inputs = inp; s_outputs = outp; s_params = params; s_constraints = constr } in + env := Ast.NameEnv.add n s !env + + let _ = + add_sig "and" [TBit;TBit] [TBit]; + add_sig "xor" [TBit;TBit] [TBit]; + add_sig "or" [TBit;TBit] [TBit]; + add_sig "not" [TBit] [TBit]; + add_sig "reg" [TBit] [TBit]; + add_sig "mux" [TBit;TBit;TBit] [TBit]; + add_sig ~params:["n"] "print" [TBitArray (mk_static_var "n"); TBit] [TBit]; + add_sig ~params:["n"] "input" [TBit] [TBitArray (mk_static_var "n")]; + let constr1 = mk_static_exp (SBinOp(SLess, mk_static_var "i", mk_static_var "n")) in + let constr2 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "i")) in + add_sig ~params:["i"; "n"] + ~constr:[constr1; constr2] + "select" [TBitArray (mk_static_var "n")] [TBit]; + let add = mk_static_exp (SBinOp(SAdd, mk_static_var "n1", mk_static_var "n2")) in + add_sig ~params:["n1"; "n2"; "n3"] + ~constr:[mk_static_exp (SBinOp (SEqual, mk_static_var "n3", add))] + "concat" [TBitArray (mk_static_var "n1"); TBitArray (mk_static_var "n2")] + [TBitArray (mk_static_var "n3")]; + (* slice : size = min <= max ? max - min + 1 : 0 *) + let size = + mk_static_exp + (SBinOp(SAdd, + mk_static_exp (SBinOp(SMinus, mk_static_var "max", mk_static_var "min")), + mk_static_int 1)) + in + let size = + mk_static_exp (SIf (mk_static_exp (SBinOp(SLeq, mk_static_var "min", mk_static_var "max")), + size, mk_static_int 0)) + in + let constr1 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "min")) in + let constr2 = mk_static_exp (SBinOp(SLess, mk_static_var "max", mk_static_var "n")) in + add_sig ~params:["min"; "max"; "n"] ~constr:[constr1; constr2] "slice" + [TBitArray (mk_static_var "n")] [TBitArray size] + + + let tys_of_vds vds = List.map (fun vd -> vd.v_ty) vds + + let add_node n constr = + let s = { s_inputs = tys_of_vds n.n_inputs; + s_outputs = tys_of_vds n.n_outputs; + s_params = List.map (fun p -> p.p_name) n.n_params; + s_constraints = constr } in + env := Ast.NameEnv.add n.n_name s !env + + let build_param_env param_names params = + List.fold_left2 + (fun env pn p -> NameEnv.add pn p env) + NameEnv.empty param_names params + + let subst_ty env ty = match ty with + | TBitArray se -> TBitArray (subst env se) + | _ -> ty + + let find_node n params = + try + let s = Ast.NameEnv.find n !env in + if List.length s.s_params <> List.length params then + error (Params_arity_error (List.length params, List.length s.s_params)); + let env = build_param_env s.s_params params in + let s = + { s with s_inputs = List.map (subst_ty env) s.s_inputs; + s_outputs = List.map (subst_ty env) s.s_outputs; + s_constraints = List.map (subst env) s.s_constraints } + in + s + with Not_found -> + Format.eprintf "Unbound node '%s'@." n; + raise Error +end + +let constr_list = ref [] +let add_constraint se = + constr_list := se :: !constr_list +let set_constraints cl = + constr_list := cl +let get_constraints () = + let v = !constr_list in + constr_list := []; v + +let fresh_static_var () = + SVar ("s_"^(Misc.gen_symbol ())) + +(* Functions on types*) + +let fresh_type = + let index = ref 0 in + let gen_index () = (incr index; !index) in + let fresh_type () = TVar (ref (TIndex (gen_index ()))) in + fresh_type + +(** returns the canonic (short) representant of [ty] + and update it to this value. *) +let rec ty_repr ty = match ty with + | TVar link -> + (match !link with + | TLink ty -> + let ty = ty_repr ty in + link := TLink ty; + ty + | _ -> ty) + | _ -> ty + +(** verifies that index is fresh in ck. *) +let rec occur_check index ty = + let ty = ty_repr ty in + match ty with + | TUnit | TBit | TBitArray _ -> () + | TVar { contents = TIndex n } when index <> n -> () + | TProd ty_list -> List.iter (occur_check index) ty_list + | _ -> raise Unify + +let rec unify ty1 ty2 = + let ty1 = ty_repr ty1 in + let ty2 = ty_repr ty2 in + if ty1 == ty2 then () + else + match (ty1, ty2) with + | TBitArray n, TBit | TBit, TBitArray n -> + add_constraint (mk_static_exp (SBinOp(SEqual, n, mk_static_int 1))) + | TBitArray n1, TBitArray n2 -> + add_constraint (mk_static_exp (SBinOp(SEqual, n1, n2))) + | TVar { contents = TIndex n1 }, TVar { contents = TIndex n2 } when n1 = n2 -> () + | TProd ty_list1, TProd ty_list2 -> + if List.length ty_list1 <> List.length ty_list2 then + error (Result_arity_error (List.length ty_list1, List.length ty_list2)); + List.iter2 unify ty_list1 ty_list2 + | TVar ({ contents = TIndex n } as link), ty + | ty, TVar ({ contents = TIndex n } as link) -> + occur_check n ty; + link := TLink ty + | _ -> raise Unify + +let prod ty_list = match ty_list with + | [ty] -> ty + | _ -> TProd ty_list + +(* Typing of static exps *) +let rec type_static_exp se = match se.se_desc with + | SInt _ | SVar _ -> STInt + | SBool _ -> STBool + | SBinOp((SAdd | SMinus | SMult | SDiv | SPower ), se1, se2) -> + expect_static_exp se1 STInt; + expect_static_exp se2 STInt; + STInt + | SBinOp((SEqual | SLess | SLeq | SGreater | SGeq), se1, se2) -> + expect_static_exp se1 STInt; + expect_static_exp se2 STInt; + STBool + | SIf (c, se1, se2) -> + expect_static_exp se1 STBool; + let ty1 = type_static_exp se1 in + expect_static_exp se2 ty1; + ty1 + +and expect_static_exp se ty = + let found_ty = type_static_exp se in + if found_ty <> ty then + error (Static_type_error (found_ty, ty)) + +let rec simplify_constr cl = match cl with + | [] -> [] + | c::cl -> + let c' = simplify NameEnv.empty c in + match c'.se_desc with + | SBool true -> simplify_constr cl + | SBool false -> error (Static_constraint_false c) + | _ -> c::(simplify_constr cl) + +let rec find_simplification_one c = match c.se_desc with + | SBinOp(SEqual, { se_desc = SVar s }, se) + | SBinOp(SEqual, se, { se_desc = SVar s }) -> + Some (s, se) + | SIf(_, se1, { se_desc = SBool true }) + | SIf(_, { se_desc = SBool true }, se1) -> + find_simplification_one se1 + | _ -> None + +let rec find_simplification params cl = match cl with + | [] -> None, [] + | c::cl -> + (match find_simplification_one c with + | Some (s, se) when not (List.mem s params) -> + Some (s, se), cl + | _ -> + let res, cl = find_simplification params cl in + res, c::cl) + +let solve_constr params cl = + let params = List.map (fun p -> p.p_name) params in + let subst_and_error env c = + let c' = subst env c in + match c'.se_desc with + | SBool false -> error (Static_constraint_false c) + | _ -> c' + in + let env = ref NameEnv.empty in + let rec solve_one cl = + let res, cl = find_simplification params cl in + match res with + | None -> cl + | Some (s, se) -> + env := NameEnv.add s se !env; + let cl = List.map (subst_and_error !env) cl in + solve_one cl + in + let cl = simplify_constr cl in + let cl = solve_one cl in + cl, !env + +(* Typing of expressions *) +let rec type_exp env e = + try + let desc, ty = match e.e_desc with + | Econst (VBit _) -> e.e_desc, TBit + | Econst (VBitArray a) -> e.e_desc, TBitArray (mk_static_int (Array.length a)) + | Evar id -> Evar id, IdentEnv.find id env + | Ereg e -> + let e = expect_exp env e TBit in + Ereg e, TBit + | Emem (MRom, addr_size, word_size, file, args) -> + (* addr_size > 0 *) + add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size))); + let read_addr = assert_1 args in + let read_addr = expect_exp env read_addr (TBitArray addr_size) in + Emem (MRom, addr_size, word_size, file, [read_addr]), TBitArray word_size + | Emem (MRam, addr_size, word_size, file, args) -> + (* addr_size > 0 *) + add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size))); + let read_addr, write_en, write_addr, data_in = assert_4 args in + let read_addr = expect_exp env read_addr (TBitArray addr_size) in + let write_addr = expect_exp env write_addr (TBitArray addr_size) in + let data_in = expect_exp env data_in (TBitArray word_size) in + let write_en = expect_exp env write_en TBit in + let args = [read_addr; write_en; write_addr; data_in] in + Emem (MRam, addr_size, word_size, file, args), TBitArray word_size + | Ecall (f, params, args) -> + let s = Modules.find_node f params in + (*check arity*) + if List.length s.s_inputs <> List.length args then + error (Args_arity_error (List.length args, List.length s.s_inputs)); + (*check types of all arguments*) + let args = List.map2 (expect_exp env) args s.s_inputs in + List.iter add_constraint s.s_constraints; + Ecall(f, params, args), prod s.s_outputs + in + { e with e_desc = desc; e_ty = ty }, ty + with + | Typing_error k -> message e.e_loc k; raise Error + +and expect_exp env e ty = + let e, found_ty = type_exp env e in + try + unify ty found_ty; + e + with + Unify -> error (Type_error (found_ty, ty)) + +let type_pat env pat = match pat with + | Evarpat x -> IdentEnv.find x env + | Etuplepat id_list -> prod (List.map (fun x -> IdentEnv.find x env) id_list) + +let type_eq env (pat, e) = + let pat_ty = type_pat env pat in + let e = expect_exp env e pat_ty in + (pat, e) + +let build env vds = + let build_one env vd = IdentEnv.add vd.v_ident vd.v_ty env in + List.fold_left build_one env vds + +let rec type_block env b = match b with + | BEqs(eqs, vds) -> + let vds = List.map (fun vd -> { vd with v_ty = fresh_type () }) vds in + let env = build env vds in + let eqs = List.map (type_eq env) eqs in + BEqs(eqs,vds) + | BIf(se, trueb, falseb) -> + expect_static_exp se STBool; + let prev_constr = get_constraints () in + let trueb = type_block env trueb in + let true_constr = + List.map (fun c -> mk_static_exp (SIf (se, c, mk_static_bool true))) (get_constraints ()) + in + let falseb = type_block env falseb in + let false_constr = + List.map (fun c -> mk_static_exp (SIf (se, mk_static_bool true, c))) (get_constraints ()) + in + set_constraints (prev_constr @ true_constr @ false_constr); + BIf(se, trueb, falseb) + +let ty_repr_block env b = + let static_exp funs acc se = simplify env se, acc in + let ty funs acc ty = + let ty = ty_repr ty in + (* go through types to substitute static exps *) + Mapfold.ty funs acc ty + in + let funs = { Mapfold.defaults with ty = ty; static_exp = static_exp } in + let b, _ = Mapfold.block_it funs () b in + b + +let node n = + try + Modules.add_node n []; + let env = build IdentEnv.empty n.n_inputs in + let env = build env n.n_outputs in + let body = type_block env n.n_body in + let constr = get_constraints () in + let constr, env = solve_constr n.n_params constr in + let body = ty_repr_block env body in + Modules.add_node n constr; + { n with n_body = body; n_constraints = constr } + with + Typing_error k -> message n.n_loc k; raise Error + +let program p = + let p_nodes = List.map node p.p_nodes in + { p with p_nodes = p_nodes } |