summaryrefslogtreecommitdiff
path: root/minijazz/src/analysis/typing.ml
diff options
context:
space:
mode:
Diffstat (limited to 'minijazz/src/analysis/typing.ml')
-rw-r--r--minijazz/src/analysis/typing.ml375
1 files changed, 375 insertions, 0 deletions
diff --git a/minijazz/src/analysis/typing.ml b/minijazz/src/analysis/typing.ml
new file mode 100644
index 0000000..0212b95
--- /dev/null
+++ b/minijazz/src/analysis/typing.ml
@@ -0,0 +1,375 @@
+open Ast
+open Static
+open Static_utils
+open Printer
+open Errors
+open Misc
+open Mapfold
+
+exception Unify
+
+type error_kind =
+ | Args_arity_error of int * int
+ | Params_arity_error of int * int
+ | Result_arity_error of int * int
+ | Type_error of ty * ty
+ | Static_type_error of static_ty * static_ty
+ | Static_constraint_false of static_exp
+
+exception Typing_error of error_kind
+
+let error k = raise (Typing_error k)
+
+let message loc err = match err with
+ | Args_arity_error (found, expected) ->
+ Format.eprintf "%aWrong number of arguments (found '%d'; expected '%d')@."
+ Location.print_location loc found expected
+ | Result_arity_error (found, expected) ->
+ Format.eprintf "%aWrong number of outputs (found '%d'; expected '%d')@."
+ Location.print_location loc found expected
+ | Params_arity_error (found, expected) ->
+ Format.eprintf "%aWrong number of static parameters (found '%d'; expected '%d')@."
+ Location.print_location loc found expected
+ | Type_error (found_ty, exp_ty) ->
+ Format.eprintf "%aThis expression has type '%a' but '%a' was expected@."
+ Location.print_location loc print_type found_ty print_type exp_ty
+ | Static_type_error (found_ty, exp_ty) ->
+ Format.eprintf "%aThis static expression has type '%a' but '%a' was expected@."
+ Location.print_location loc print_static_type found_ty print_static_type exp_ty
+ | Static_constraint_false se ->
+ Format.eprintf "%aThe following constraint is not satisfied: %a@."
+ Location.print_location loc print_static_exp se
+
+
+type signature =
+ { s_inputs : ty list;
+ s_outputs : ty list;
+ s_params : name list;
+ s_constraints : static_exp list }
+
+module Modules = struct
+ let env = ref Ast.NameEnv.empty
+
+ let add_sig ?(params = []) ?(constr = []) n inp outp =
+ let s = { s_inputs = inp; s_outputs = outp; s_params = params; s_constraints = constr } in
+ env := Ast.NameEnv.add n s !env
+
+ let _ =
+ add_sig "and" [TBit;TBit] [TBit];
+ add_sig "xor" [TBit;TBit] [TBit];
+ add_sig "or" [TBit;TBit] [TBit];
+ add_sig "not" [TBit] [TBit];
+ add_sig "reg" [TBit] [TBit];
+ add_sig "mux" [TBit;TBit;TBit] [TBit];
+ add_sig ~params:["n"] "print" [TBitArray (mk_static_var "n"); TBit] [TBit];
+ add_sig ~params:["n"] "input" [TBit] [TBitArray (mk_static_var "n")];
+ let constr1 = mk_static_exp (SBinOp(SLess, mk_static_var "i", mk_static_var "n")) in
+ let constr2 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "i")) in
+ add_sig ~params:["i"; "n"]
+ ~constr:[constr1; constr2]
+ "select" [TBitArray (mk_static_var "n")] [TBit];
+ let add = mk_static_exp (SBinOp(SAdd, mk_static_var "n1", mk_static_var "n2")) in
+ add_sig ~params:["n1"; "n2"; "n3"]
+ ~constr:[mk_static_exp (SBinOp (SEqual, mk_static_var "n3", add))]
+ "concat" [TBitArray (mk_static_var "n1"); TBitArray (mk_static_var "n2")]
+ [TBitArray (mk_static_var "n3")];
+ (* slice : size = min <= max ? max - min + 1 : 0 *)
+ let size =
+ mk_static_exp
+ (SBinOp(SAdd,
+ mk_static_exp (SBinOp(SMinus, mk_static_var "max", mk_static_var "min")),
+ mk_static_int 1))
+ in
+ let size =
+ mk_static_exp (SIf (mk_static_exp (SBinOp(SLeq, mk_static_var "min", mk_static_var "max")),
+ size, mk_static_int 0))
+ in
+ let constr1 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "min")) in
+ let constr2 = mk_static_exp (SBinOp(SLess, mk_static_var "max", mk_static_var "n")) in
+ add_sig ~params:["min"; "max"; "n"] ~constr:[constr1; constr2] "slice"
+ [TBitArray (mk_static_var "n")] [TBitArray size]
+
+
+ let tys_of_vds vds = List.map (fun vd -> vd.v_ty) vds
+
+ let add_node n constr =
+ let s = { s_inputs = tys_of_vds n.n_inputs;
+ s_outputs = tys_of_vds n.n_outputs;
+ s_params = List.map (fun p -> p.p_name) n.n_params;
+ s_constraints = constr } in
+ env := Ast.NameEnv.add n.n_name s !env
+
+ let build_param_env param_names params =
+ List.fold_left2
+ (fun env pn p -> NameEnv.add pn p env)
+ NameEnv.empty param_names params
+
+ let subst_ty env ty = match ty with
+ | TBitArray se -> TBitArray (subst env se)
+ | _ -> ty
+
+ let find_node n params =
+ try
+ let s = Ast.NameEnv.find n !env in
+ if List.length s.s_params <> List.length params then
+ error (Params_arity_error (List.length params, List.length s.s_params));
+ let env = build_param_env s.s_params params in
+ let s =
+ { s with s_inputs = List.map (subst_ty env) s.s_inputs;
+ s_outputs = List.map (subst_ty env) s.s_outputs;
+ s_constraints = List.map (subst env) s.s_constraints }
+ in
+ s
+ with Not_found ->
+ Format.eprintf "Unbound node '%s'@." n;
+ raise Error
+end
+
+let constr_list = ref []
+let add_constraint se =
+ constr_list := se :: !constr_list
+let set_constraints cl =
+ constr_list := cl
+let get_constraints () =
+ let v = !constr_list in
+ constr_list := []; v
+
+let fresh_static_var () =
+ SVar ("s_"^(Misc.gen_symbol ()))
+
+(* Functions on types*)
+
+let fresh_type =
+ let index = ref 0 in
+ let gen_index () = (incr index; !index) in
+ let fresh_type () = TVar (ref (TIndex (gen_index ()))) in
+ fresh_type
+
+(** returns the canonic (short) representant of [ty]
+ and update it to this value. *)
+let rec ty_repr ty = match ty with
+ | TVar link ->
+ (match !link with
+ | TLink ty ->
+ let ty = ty_repr ty in
+ link := TLink ty;
+ ty
+ | _ -> ty)
+ | _ -> ty
+
+(** verifies that index is fresh in ck. *)
+let rec occur_check index ty =
+ let ty = ty_repr ty in
+ match ty with
+ | TUnit | TBit | TBitArray _ -> ()
+ | TVar { contents = TIndex n } when index <> n -> ()
+ | TProd ty_list -> List.iter (occur_check index) ty_list
+ | _ -> raise Unify
+
+let rec unify ty1 ty2 =
+ let ty1 = ty_repr ty1 in
+ let ty2 = ty_repr ty2 in
+ if ty1 == ty2 then ()
+ else
+ match (ty1, ty2) with
+ | TBitArray n, TBit | TBit, TBitArray n ->
+ add_constraint (mk_static_exp (SBinOp(SEqual, n, mk_static_int 1)))
+ | TBitArray n1, TBitArray n2 ->
+ add_constraint (mk_static_exp (SBinOp(SEqual, n1, n2)))
+ | TVar { contents = TIndex n1 }, TVar { contents = TIndex n2 } when n1 = n2 -> ()
+ | TProd ty_list1, TProd ty_list2 ->
+ if List.length ty_list1 <> List.length ty_list2 then
+ error (Result_arity_error (List.length ty_list1, List.length ty_list2));
+ List.iter2 unify ty_list1 ty_list2
+ | TVar ({ contents = TIndex n } as link), ty
+ | ty, TVar ({ contents = TIndex n } as link) ->
+ occur_check n ty;
+ link := TLink ty
+ | _ -> raise Unify
+
+let prod ty_list = match ty_list with
+ | [ty] -> ty
+ | _ -> TProd ty_list
+
+(* Typing of static exps *)
+let rec type_static_exp se = match se.se_desc with
+ | SInt _ | SVar _ -> STInt
+ | SBool _ -> STBool
+ | SBinOp((SAdd | SMinus | SMult | SDiv | SPower ), se1, se2) ->
+ expect_static_exp se1 STInt;
+ expect_static_exp se2 STInt;
+ STInt
+ | SBinOp((SEqual | SLess | SLeq | SGreater | SGeq), se1, se2) ->
+ expect_static_exp se1 STInt;
+ expect_static_exp se2 STInt;
+ STBool
+ | SIf (c, se1, se2) ->
+ expect_static_exp se1 STBool;
+ let ty1 = type_static_exp se1 in
+ expect_static_exp se2 ty1;
+ ty1
+
+and expect_static_exp se ty =
+ let found_ty = type_static_exp se in
+ if found_ty <> ty then
+ error (Static_type_error (found_ty, ty))
+
+let rec simplify_constr cl = match cl with
+ | [] -> []
+ | c::cl ->
+ let c' = simplify NameEnv.empty c in
+ match c'.se_desc with
+ | SBool true -> simplify_constr cl
+ | SBool false -> error (Static_constraint_false c)
+ | _ -> c::(simplify_constr cl)
+
+let rec find_simplification_one c = match c.se_desc with
+ | SBinOp(SEqual, { se_desc = SVar s }, se)
+ | SBinOp(SEqual, se, { se_desc = SVar s }) ->
+ Some (s, se)
+ | SIf(_, se1, { se_desc = SBool true })
+ | SIf(_, { se_desc = SBool true }, se1) ->
+ find_simplification_one se1
+ | _ -> None
+
+let rec find_simplification params cl = match cl with
+ | [] -> None, []
+ | c::cl ->
+ (match find_simplification_one c with
+ | Some (s, se) when not (List.mem s params) ->
+ Some (s, se), cl
+ | _ ->
+ let res, cl = find_simplification params cl in
+ res, c::cl)
+
+let solve_constr params cl =
+ let params = List.map (fun p -> p.p_name) params in
+ let subst_and_error env c =
+ let c' = subst env c in
+ match c'.se_desc with
+ | SBool false -> error (Static_constraint_false c)
+ | _ -> c'
+ in
+ let env = ref NameEnv.empty in
+ let rec solve_one cl =
+ let res, cl = find_simplification params cl in
+ match res with
+ | None -> cl
+ | Some (s, se) ->
+ env := NameEnv.add s se !env;
+ let cl = List.map (subst_and_error !env) cl in
+ solve_one cl
+ in
+ let cl = simplify_constr cl in
+ let cl = solve_one cl in
+ cl, !env
+
+(* Typing of expressions *)
+let rec type_exp env e =
+ try
+ let desc, ty = match e.e_desc with
+ | Econst (VBit _) -> e.e_desc, TBit
+ | Econst (VBitArray a) -> e.e_desc, TBitArray (mk_static_int (Array.length a))
+ | Evar id -> Evar id, IdentEnv.find id env
+ | Ereg e ->
+ let e = expect_exp env e TBit in
+ Ereg e, TBit
+ | Emem (MRom, addr_size, word_size, file, args) ->
+ (* addr_size > 0 *)
+ add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size)));
+ let read_addr = assert_1 args in
+ let read_addr = expect_exp env read_addr (TBitArray addr_size) in
+ Emem (MRom, addr_size, word_size, file, [read_addr]), TBitArray word_size
+ | Emem (MRam, addr_size, word_size, file, args) ->
+ (* addr_size > 0 *)
+ add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size)));
+ let read_addr, write_en, write_addr, data_in = assert_4 args in
+ let read_addr = expect_exp env read_addr (TBitArray addr_size) in
+ let write_addr = expect_exp env write_addr (TBitArray addr_size) in
+ let data_in = expect_exp env data_in (TBitArray word_size) in
+ let write_en = expect_exp env write_en TBit in
+ let args = [read_addr; write_en; write_addr; data_in] in
+ Emem (MRam, addr_size, word_size, file, args), TBitArray word_size
+ | Ecall (f, params, args) ->
+ let s = Modules.find_node f params in
+ (*check arity*)
+ if List.length s.s_inputs <> List.length args then
+ error (Args_arity_error (List.length args, List.length s.s_inputs));
+ (*check types of all arguments*)
+ let args = List.map2 (expect_exp env) args s.s_inputs in
+ List.iter add_constraint s.s_constraints;
+ Ecall(f, params, args), prod s.s_outputs
+ in
+ { e with e_desc = desc; e_ty = ty }, ty
+ with
+ | Typing_error k -> message e.e_loc k; raise Error
+
+and expect_exp env e ty =
+ let e, found_ty = type_exp env e in
+ try
+ unify ty found_ty;
+ e
+ with
+ Unify -> error (Type_error (found_ty, ty))
+
+let type_pat env pat = match pat with
+ | Evarpat x -> IdentEnv.find x env
+ | Etuplepat id_list -> prod (List.map (fun x -> IdentEnv.find x env) id_list)
+
+let type_eq env (pat, e) =
+ let pat_ty = type_pat env pat in
+ let e = expect_exp env e pat_ty in
+ (pat, e)
+
+let build env vds =
+ let build_one env vd = IdentEnv.add vd.v_ident vd.v_ty env in
+ List.fold_left build_one env vds
+
+let rec type_block env b = match b with
+ | BEqs(eqs, vds) ->
+ let vds = List.map (fun vd -> { vd with v_ty = fresh_type () }) vds in
+ let env = build env vds in
+ let eqs = List.map (type_eq env) eqs in
+ BEqs(eqs,vds)
+ | BIf(se, trueb, falseb) ->
+ expect_static_exp se STBool;
+ let prev_constr = get_constraints () in
+ let trueb = type_block env trueb in
+ let true_constr =
+ List.map (fun c -> mk_static_exp (SIf (se, c, mk_static_bool true))) (get_constraints ())
+ in
+ let falseb = type_block env falseb in
+ let false_constr =
+ List.map (fun c -> mk_static_exp (SIf (se, mk_static_bool true, c))) (get_constraints ())
+ in
+ set_constraints (prev_constr @ true_constr @ false_constr);
+ BIf(se, trueb, falseb)
+
+let ty_repr_block env b =
+ let static_exp funs acc se = simplify env se, acc in
+ let ty funs acc ty =
+ let ty = ty_repr ty in
+ (* go through types to substitute static exps *)
+ Mapfold.ty funs acc ty
+ in
+ let funs = { Mapfold.defaults with ty = ty; static_exp = static_exp } in
+ let b, _ = Mapfold.block_it funs () b in
+ b
+
+let node n =
+ try
+ Modules.add_node n [];
+ let env = build IdentEnv.empty n.n_inputs in
+ let env = build env n.n_outputs in
+ let body = type_block env n.n_body in
+ let constr = get_constraints () in
+ let constr, env = solve_constr n.n_params constr in
+ let body = ty_repr_block env body in
+ Modules.add_node n constr;
+ { n with n_body = body; n_constraints = constr }
+ with
+ Typing_error k -> message n.n_loc k; raise Error
+
+let program p =
+ let p_nodes = List.map node p.p_nodes in
+ { p with p_nodes = p_nodes }