diff options
Diffstat (limited to 'minijazz/src/analysis')
-rw-r--r-- | minijazz/src/analysis/callgraph.ml | 198 | ||||
-rw-r--r-- | minijazz/src/analysis/normalize.ml | 40 | ||||
-rw-r--r-- | minijazz/src/analysis/scoping.ml | 99 | ||||
-rw-r--r-- | minijazz/src/analysis/simplify.ml | 42 | ||||
-rw-r--r-- | minijazz/src/analysis/typing.ml | 375 |
5 files changed, 754 insertions, 0 deletions
diff --git a/minijazz/src/analysis/callgraph.ml b/minijazz/src/analysis/callgraph.ml new file mode 100644 index 0000000..d1fab8f --- /dev/null +++ b/minijazz/src/analysis/callgraph.ml @@ -0,0 +1,198 @@ +open Ast +open Mapfold +open Static +open Static_utils +open Location +open Errors + +(** Inlines all nodes with static paramaters. *) + +let expect_bool env se = + let se = simplify env se in + match se.se_desc with + | SBool v -> v + | _ -> Format.eprintf "Expected a boolean@."; raise Error + +let expect_int env se = + let se = simplify env se in + match se.se_desc with + | SInt v -> v + | _ -> Format.eprintf "Expected an integer@."; raise Error + +let simplify_ty env ty = match ty with + | TBitArray se -> TBitArray (simplify env se) + | _ -> ty + +(** Find a node by name*) +let nodes_list = ref [] +let find_node f = + List.find (fun n -> f = n.n_name) !nodes_list + +let vars_of_pat env pat = + let exp_of_ident id = + try + let ty = IdentEnv.find id env in + mk_exp ~ty:ty (Evar id) + with + | Not_found -> Format.eprintf "Not in env: %a@." Ident.print_ident id; assert false + in + let rec _vars_of_pat acc pat = match pat with + | Evarpat id -> (exp_of_ident id)::acc + | Etuplepat l -> List.fold_left (fun acc id -> (exp_of_ident id)::acc) acc l + in + _vars_of_pat [] pat + +let ident_of_exp e = match e.e_desc with + | Evar x -> x + | _ -> assert false + +let rename env vd = + let e = mk_exp ~ty:vd.v_ty (Evar (Ident.copy vd.v_ident)) in + IdentEnv.add vd.v_ident e env + +let build_params m names values = + List.fold_left2 (fun m { p_name = n } v -> NameEnv.add n v m) m names values + +let build_exp m vds values = + List.fold_left2 (fun m { v_ident = n } e -> IdentEnv.add n e m) m vds values + +let build_env env vds = + List.fold_left (fun env vd -> IdentEnv.add vd.v_ident vd.v_ty env) env vds + +let rec find_local_vars b = match b with + | BEqs (_, vds) -> vds + | BIf (_, trueb, falseb) -> (find_local_vars trueb) @ (find_local_vars falseb) + +(** Substitutes idents with new names, static params with their values *) +let do_subst_block m subst b = + let translate_ident subst id = + try + ident_of_exp (IdentEnv.find id subst) + with + | Not_found -> id + in + let static_exp funs (subst, m) se = + simplify m se, (subst, m) + in + let exp funs (subst, m) e = + let e, _ = Mapfold.exp funs (subst, m) e in + match e.e_desc with + | Evar x -> + let e = if IdentEnv.mem x subst then IdentEnv.find x subst else e in + e, (subst, m) + | _ -> Mapfold.exp funs (subst, m) e + in + let pat funs (subst, m) pat = match pat with + | Evarpat id -> Evarpat (translate_ident subst id), (subst, m) + | Etuplepat ids -> Etuplepat (List.map (translate_ident subst) ids), (subst, m) + in + let var_dec funs (subst, m) vd = + (* iterate on the type *) + let vd, _ = Mapfold.var_dec funs (subst, m) vd in + { vd with v_ident = translate_ident subst vd.v_ident }, (subst, m) + in + let funs = + { Mapfold.defaults with static_exp = static_exp; exp = exp; + pat = pat; var_dec = var_dec } + in + let b, _ = Mapfold.block_it funs (subst, m) b in + b + +let check_params loc m param_names params cl = + let env = build_params NameEnv.empty param_names params in + let cl = List.map (simplify env) cl in + try + check_true m cl + with Unsatisfiable(c) -> + Format.eprintf "%aThe following constraint is not satisfied: %a@." + print_location loc Printer.print_static_exp c; + raise Error + +let rec inline_node loc env m call_stack f params args pat = + (* Check that the definition is sound *) + if List.mem (f, params) call_stack then ( + Format.eprintf "The definition of %s is circular.@." f; + raise Error + ); + let call_stack = (f, params)::call_stack in + + (* do the actual work *) + let n = find_node f in + check_params loc m n.n_params params n.n_constraints; + let m = build_params m n.n_params params in + let subst = build_exp IdentEnv.empty n.n_inputs args in + let subst = build_exp subst n.n_outputs (List.rev (vars_of_pat env pat)) in + let locals = find_local_vars n.n_body in + let subst = List.fold_left rename subst locals in + let b = do_subst_block m subst n.n_body in + let b = Normalize.block b in + b, call_stack + +and translate_eq env m subst call_stack (eqs, vds) ((pat, e) as eq) = + match e.e_desc with + (* Inline all nodes or only those with params or declared inline + if no_inline_all = true *) + | Ecall(f, params, args) -> + (try + let n = find_node f in + if not !Cli_options.no_inline_all + || not (Misc.is_empty params) + || n.n_inlined = Inlined then + let params = List.map (simplify m) params in + let b, call_stack = inline_node e.e_loc env m call_stack f params args pat in + let new_eqs, new_vds = translate_block env m subst call_stack b in + new_eqs@eqs, new_vds@vds + else + eq::eqs, vds + with + | Not_found -> eq::eqs, vds (* Predefined function*) + ) + | _ -> eq::eqs, vds + +and translate_eqs env m subst call_stack acc eqs = + List.fold_left (translate_eq env m subst call_stack) acc eqs + +and translate_block env m subst call_stack b = + match b with + | BEqs (eqs, vds) -> + let vds = List.map (fun vd -> { vd with v_ty = simplify_ty m vd.v_ty }) vds in + let env = build_env env vds in + translate_eqs env m subst call_stack ([], vds) eqs + | BIf(se, trueb, elseb) -> + if expect_bool m se then + translate_block env m subst call_stack trueb + else + translate_block env m subst call_stack elseb + +let node m n = + (*Init state*) + let call_stack = [(n.n_name, [])] in + (*Do the translation*) + let env = build_env IdentEnv.empty n.n_inputs in + let env = build_env env n.n_outputs in + let eqs, vds = translate_block env m IdentEnv.empty call_stack n.n_body in + { n with n_body = BEqs (eqs, vds) } + +let build_cd env cd = + NameEnv.add cd.c_name cd.c_value env + +let program p = + nodes_list := p.p_nodes; + let m = List.fold_left build_cd NameEnv.empty p.p_consts in + if !Cli_options.no_inline_all then ( + (* Find the nodes without static parameters *) + let nodes = List.filter (fun n -> Misc.is_empty n.n_params) p.p_nodes in + let nodes = List.map (fun n -> node m n) nodes in + { p with p_nodes = nodes } + ) else ( + try + let n = List.find (fun n -> n.n_name = !Cli_options.main_node) p.p_nodes in + if n.n_params <> [] then ( + Format.eprintf "The main node '%s' cannot have static parameters@." n.n_name; + raise Error + ); + { p with p_nodes = [node m n] } + with Not_found -> + Format.eprintf "Cannot find the main node '%s'@." !Cli_options.main_node; + raise Error + ) diff --git a/minijazz/src/analysis/normalize.ml b/minijazz/src/analysis/normalize.ml new file mode 100644 index 0000000..52db539 --- /dev/null +++ b/minijazz/src/analysis/normalize.ml @@ -0,0 +1,40 @@ +open Ast +open Mapfold + +let mk_eq e = + let id = Ident.fresh_ident "_l" in + let eq = (Evarpat id, e) in + let vd = mk_var_dec id e.e_ty in + Evar id, vd, eq + +(* Put all the arguments in separate equations *) +let exp funs (eqs, vds) e = match e.e_desc with + | Econst _ | Evar _ -> e, (eqs, vds) + | _ -> + let e, (eqs, vds) = Mapfold.exp funs (eqs, vds) e in + let desc, vd, eq = mk_eq e in + { e with e_desc = desc }, (eq::eqs, vd::vds) + +let equation funs (eqs, vds) (pat, e) = + match e.e_desc with + | Econst _ | Evar _ -> (pat, e), (eqs, vds) + | _ -> + let _, ((_, e)::eqs, _::vds) = Mapfold.exp_it funs (eqs, vds) e in + (pat, e), (eqs, vds) + +let block funs acc b = match b with + | BEqs(eqs, vds) -> + let eqs, (new_eqs, new_vds) = Misc.mapfold (Mapfold.equation_it funs) ([], []) eqs in + BEqs(new_eqs@eqs, new_vds@vds), acc + | BIf _ -> raise Mapfold.Fallback + +let program p = + let funs = { Mapfold.defaults with exp = exp; equation = equation; block = block } in + let p, _ = Mapfold.program_it funs ([], []) p in + p + +(* Used by Callgraph *) +let block b = + let funs = { Mapfold.defaults with exp = exp; equation = equation; block = block } in + let b, _ = Mapfold.block_it funs ([], []) b in + b diff --git a/minijazz/src/analysis/scoping.ml b/minijazz/src/analysis/scoping.ml new file mode 100644 index 0000000..0fefd5b --- /dev/null +++ b/minijazz/src/analysis/scoping.ml @@ -0,0 +1,99 @@ +open Ast +open Mapfold +open Static +open Static_utils +open Location +open Errors + +(** Simplifies static expression in the program. *) +let simplify_program p = + let const_dec funs cenv cd = + let v = subst cenv cd.c_value in + let cenv = NameEnv.add cd.c_name v cenv in + { cd with c_value = v }, cenv + in + let static_exp funs cenv se = + let se = subst cenv se in + (match se.se_desc with + | SVar id -> + (* Constants with se.se_loc = no_location are generated and should not be checked *) + if not (NameEnv.mem id cenv) && not (se.se_loc == no_location) then ( + Format.eprintf "%aThe constant name '%s' is unbound@." + print_location se.se_loc id; + raise Error + ) + | _ -> () + ); + se, cenv + in + let node_dec funs cenv nd = + let cenv' = + List.fold_left + (fun cenv p -> NameEnv.add p.p_name (mk_static_var p.p_name) cenv) + cenv nd.n_params + in + let nd, _ = Mapfold.node_dec funs cenv' nd in + nd, cenv + in + let funs = + { Mapfold.defaults with const_dec = const_dec; + static_exp = static_exp; node_dec = node_dec } + in + let p, _ = Mapfold.program_it funs NameEnv.empty p in + p + +(** Checks the name used in the program are defined. + Adds var_decs for all variables defined in a block. *) +let check_names p = + let rec pat_vars s pat = match pat with + | Evarpat id -> IdentSet.add id s + | Etuplepat ids -> List.fold_left (fun s id -> IdentSet.add id s) s ids + in + let build_set vds = + List.fold_left (fun s vd -> IdentSet.add vd.v_ident s) IdentSet.empty vds + in + let block funs (s, _) b = match b with + | BEqs(eqs, _) -> + let defnames = List.fold_left (fun s (pat, _) -> pat_vars s pat) IdentSet.empty eqs in + let ls = IdentSet.diff defnames s in (* remove outputs from the set *) + let vds = IdentSet.fold (fun id l -> (mk_var_dec id invalid_type)::l) ls [] in + let new_s = IdentSet.union s defnames in + let eqs,_ = Misc.mapfold (Mapfold.equation_it funs) (new_s, IdentSet.empty) eqs in + BEqs (eqs, vds), (s, defnames) + | BIf(se, trueb, falseb) -> + let trueb, (_, def_true) = Mapfold.block_it funs (s, IdentSet.empty) trueb in + let falseb, (_, def_false) = Mapfold.block_it funs (s, IdentSet.empty) falseb in + let defnames = IdentSet.inter def_true def_false in + BIf(se, trueb, falseb), (s, defnames) + in + let exp funs (s, defnames) e = match e.e_desc with + | Evar id -> + if not (IdentSet.mem id s) then ( + Format.eprintf "%aThe identifier '%a' is unbound@." + print_location e.e_loc Ident.print_ident id; + raise Error + ); + e, (s, defnames) + | _ -> Mapfold.exp funs (s, defnames) e + in + let node n = + let funs = { Mapfold.defaults with block = block; exp = exp } in + let s = build_set (n.n_inputs@n.n_outputs) in + let n_body, (_, defnames) = Mapfold.block_it funs (s, IdentSet.empty) n.n_body in + (* check for undefined outputs *) + let undefined_outputs = + List.filter (fun vd -> not (IdentSet.mem vd.v_ident defnames)) n.n_outputs + in + if undefined_outputs <> [] then ( + Format.eprintf "%aThe following outputs are not defined: %a@." + print_location n.n_loc Printer.print_var_decs undefined_outputs; + raise Error + ); + { n with n_body = n_body } + in + { p with p_nodes = List.map node p.p_nodes } + + +let program p = + let p = simplify_program p in + check_names p diff --git a/minijazz/src/analysis/simplify.ml b/minijazz/src/analysis/simplify.ml new file mode 100644 index 0000000..f7cbb5c --- /dev/null +++ b/minijazz/src/analysis/simplify.ml @@ -0,0 +1,42 @@ +open Ast +open Static + +let is_not_zero ty = match ty with + | TBitArray { se_desc = SInt 0 } -> false + | _ -> true + +let rec simplify_exp e = match e.e_desc with + (* replace x[i..j] with [] if j < i *) + | Ecall("slice", + [{ se_desc = SInt min }; + { se_desc = SInt max }; n], _) when max < min -> + { e with e_desc = Econst (VBitArray (Array.make 0 false)) } + (* replace x[i..i] with x[i] *) + | Ecall("slice", [min; max; n], args) when min = max -> + let new_e = { e with e_desc = Ecall("select", [min; n], args) } in + simplify_exp new_e + (* replace x.[] or [].x with x *) + | Ecall("concat", _, [{ e_ty = TBitArray { se_desc = SInt 0 } }; e1]) + | Ecall("concat", _, [e1; { e_ty = TBitArray { se_desc = SInt 0 } }]) -> + e1 + | Ecall(f, params, args) -> + { e with e_desc = Ecall(f, params, List.map simplify_exp args) } + | _ -> e + +let simplify_eq (pat,e) = + (pat, simplify_exp e) + +let rec block b = match b with + | BEqs(eqs, vds) -> + let eqs = List.map simplify_eq eqs in + (* remove variables with size 0 *) + let vds = List.filter (fun vd -> is_not_zero vd.v_ty) vds in + let eqs = List.filter (fun (_, e) -> is_not_zero e.e_ty) eqs in + BEqs(eqs, vds) + | BIf(se, trueb, elseb) -> BIf(se, block trueb, block elseb) + +let node n = + { n with n_body = block n.n_body } + +let program p = + { p with p_nodes = List.map node p.p_nodes } diff --git a/minijazz/src/analysis/typing.ml b/minijazz/src/analysis/typing.ml new file mode 100644 index 0000000..0212b95 --- /dev/null +++ b/minijazz/src/analysis/typing.ml @@ -0,0 +1,375 @@ +open Ast +open Static +open Static_utils +open Printer +open Errors +open Misc +open Mapfold + +exception Unify + +type error_kind = + | Args_arity_error of int * int + | Params_arity_error of int * int + | Result_arity_error of int * int + | Type_error of ty * ty + | Static_type_error of static_ty * static_ty + | Static_constraint_false of static_exp + +exception Typing_error of error_kind + +let error k = raise (Typing_error k) + +let message loc err = match err with + | Args_arity_error (found, expected) -> + Format.eprintf "%aWrong number of arguments (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Result_arity_error (found, expected) -> + Format.eprintf "%aWrong number of outputs (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Params_arity_error (found, expected) -> + Format.eprintf "%aWrong number of static parameters (found '%d'; expected '%d')@." + Location.print_location loc found expected + | Type_error (found_ty, exp_ty) -> + Format.eprintf "%aThis expression has type '%a' but '%a' was expected@." + Location.print_location loc print_type found_ty print_type exp_ty + | Static_type_error (found_ty, exp_ty) -> + Format.eprintf "%aThis static expression has type '%a' but '%a' was expected@." + Location.print_location loc print_static_type found_ty print_static_type exp_ty + | Static_constraint_false se -> + Format.eprintf "%aThe following constraint is not satisfied: %a@." + Location.print_location loc print_static_exp se + + +type signature = + { s_inputs : ty list; + s_outputs : ty list; + s_params : name list; + s_constraints : static_exp list } + +module Modules = struct + let env = ref Ast.NameEnv.empty + + let add_sig ?(params = []) ?(constr = []) n inp outp = + let s = { s_inputs = inp; s_outputs = outp; s_params = params; s_constraints = constr } in + env := Ast.NameEnv.add n s !env + + let _ = + add_sig "and" [TBit;TBit] [TBit]; + add_sig "xor" [TBit;TBit] [TBit]; + add_sig "or" [TBit;TBit] [TBit]; + add_sig "not" [TBit] [TBit]; + add_sig "reg" [TBit] [TBit]; + add_sig "mux" [TBit;TBit;TBit] [TBit]; + add_sig ~params:["n"] "print" [TBitArray (mk_static_var "n"); TBit] [TBit]; + add_sig ~params:["n"] "input" [TBit] [TBitArray (mk_static_var "n")]; + let constr1 = mk_static_exp (SBinOp(SLess, mk_static_var "i", mk_static_var "n")) in + let constr2 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "i")) in + add_sig ~params:["i"; "n"] + ~constr:[constr1; constr2] + "select" [TBitArray (mk_static_var "n")] [TBit]; + let add = mk_static_exp (SBinOp(SAdd, mk_static_var "n1", mk_static_var "n2")) in + add_sig ~params:["n1"; "n2"; "n3"] + ~constr:[mk_static_exp (SBinOp (SEqual, mk_static_var "n3", add))] + "concat" [TBitArray (mk_static_var "n1"); TBitArray (mk_static_var "n2")] + [TBitArray (mk_static_var "n3")]; + (* slice : size = min <= max ? max - min + 1 : 0 *) + let size = + mk_static_exp + (SBinOp(SAdd, + mk_static_exp (SBinOp(SMinus, mk_static_var "max", mk_static_var "min")), + mk_static_int 1)) + in + let size = + mk_static_exp (SIf (mk_static_exp (SBinOp(SLeq, mk_static_var "min", mk_static_var "max")), + size, mk_static_int 0)) + in + let constr1 = mk_static_exp (SBinOp(SLeq, mk_static_int 0, mk_static_var "min")) in + let constr2 = mk_static_exp (SBinOp(SLess, mk_static_var "max", mk_static_var "n")) in + add_sig ~params:["min"; "max"; "n"] ~constr:[constr1; constr2] "slice" + [TBitArray (mk_static_var "n")] [TBitArray size] + + + let tys_of_vds vds = List.map (fun vd -> vd.v_ty) vds + + let add_node n constr = + let s = { s_inputs = tys_of_vds n.n_inputs; + s_outputs = tys_of_vds n.n_outputs; + s_params = List.map (fun p -> p.p_name) n.n_params; + s_constraints = constr } in + env := Ast.NameEnv.add n.n_name s !env + + let build_param_env param_names params = + List.fold_left2 + (fun env pn p -> NameEnv.add pn p env) + NameEnv.empty param_names params + + let subst_ty env ty = match ty with + | TBitArray se -> TBitArray (subst env se) + | _ -> ty + + let find_node n params = + try + let s = Ast.NameEnv.find n !env in + if List.length s.s_params <> List.length params then + error (Params_arity_error (List.length params, List.length s.s_params)); + let env = build_param_env s.s_params params in + let s = + { s with s_inputs = List.map (subst_ty env) s.s_inputs; + s_outputs = List.map (subst_ty env) s.s_outputs; + s_constraints = List.map (subst env) s.s_constraints } + in + s + with Not_found -> + Format.eprintf "Unbound node '%s'@." n; + raise Error +end + +let constr_list = ref [] +let add_constraint se = + constr_list := se :: !constr_list +let set_constraints cl = + constr_list := cl +let get_constraints () = + let v = !constr_list in + constr_list := []; v + +let fresh_static_var () = + SVar ("s_"^(Misc.gen_symbol ())) + +(* Functions on types*) + +let fresh_type = + let index = ref 0 in + let gen_index () = (incr index; !index) in + let fresh_type () = TVar (ref (TIndex (gen_index ()))) in + fresh_type + +(** returns the canonic (short) representant of [ty] + and update it to this value. *) +let rec ty_repr ty = match ty with + | TVar link -> + (match !link with + | TLink ty -> + let ty = ty_repr ty in + link := TLink ty; + ty + | _ -> ty) + | _ -> ty + +(** verifies that index is fresh in ck. *) +let rec occur_check index ty = + let ty = ty_repr ty in + match ty with + | TUnit | TBit | TBitArray _ -> () + | TVar { contents = TIndex n } when index <> n -> () + | TProd ty_list -> List.iter (occur_check index) ty_list + | _ -> raise Unify + +let rec unify ty1 ty2 = + let ty1 = ty_repr ty1 in + let ty2 = ty_repr ty2 in + if ty1 == ty2 then () + else + match (ty1, ty2) with + | TBitArray n, TBit | TBit, TBitArray n -> + add_constraint (mk_static_exp (SBinOp(SEqual, n, mk_static_int 1))) + | TBitArray n1, TBitArray n2 -> + add_constraint (mk_static_exp (SBinOp(SEqual, n1, n2))) + | TVar { contents = TIndex n1 }, TVar { contents = TIndex n2 } when n1 = n2 -> () + | TProd ty_list1, TProd ty_list2 -> + if List.length ty_list1 <> List.length ty_list2 then + error (Result_arity_error (List.length ty_list1, List.length ty_list2)); + List.iter2 unify ty_list1 ty_list2 + | TVar ({ contents = TIndex n } as link), ty + | ty, TVar ({ contents = TIndex n } as link) -> + occur_check n ty; + link := TLink ty + | _ -> raise Unify + +let prod ty_list = match ty_list with + | [ty] -> ty + | _ -> TProd ty_list + +(* Typing of static exps *) +let rec type_static_exp se = match se.se_desc with + | SInt _ | SVar _ -> STInt + | SBool _ -> STBool + | SBinOp((SAdd | SMinus | SMult | SDiv | SPower ), se1, se2) -> + expect_static_exp se1 STInt; + expect_static_exp se2 STInt; + STInt + | SBinOp((SEqual | SLess | SLeq | SGreater | SGeq), se1, se2) -> + expect_static_exp se1 STInt; + expect_static_exp se2 STInt; + STBool + | SIf (c, se1, se2) -> + expect_static_exp se1 STBool; + let ty1 = type_static_exp se1 in + expect_static_exp se2 ty1; + ty1 + +and expect_static_exp se ty = + let found_ty = type_static_exp se in + if found_ty <> ty then + error (Static_type_error (found_ty, ty)) + +let rec simplify_constr cl = match cl with + | [] -> [] + | c::cl -> + let c' = simplify NameEnv.empty c in + match c'.se_desc with + | SBool true -> simplify_constr cl + | SBool false -> error (Static_constraint_false c) + | _ -> c::(simplify_constr cl) + +let rec find_simplification_one c = match c.se_desc with + | SBinOp(SEqual, { se_desc = SVar s }, se) + | SBinOp(SEqual, se, { se_desc = SVar s }) -> + Some (s, se) + | SIf(_, se1, { se_desc = SBool true }) + | SIf(_, { se_desc = SBool true }, se1) -> + find_simplification_one se1 + | _ -> None + +let rec find_simplification params cl = match cl with + | [] -> None, [] + | c::cl -> + (match find_simplification_one c with + | Some (s, se) when not (List.mem s params) -> + Some (s, se), cl + | _ -> + let res, cl = find_simplification params cl in + res, c::cl) + +let solve_constr params cl = + let params = List.map (fun p -> p.p_name) params in + let subst_and_error env c = + let c' = subst env c in + match c'.se_desc with + | SBool false -> error (Static_constraint_false c) + | _ -> c' + in + let env = ref NameEnv.empty in + let rec solve_one cl = + let res, cl = find_simplification params cl in + match res with + | None -> cl + | Some (s, se) -> + env := NameEnv.add s se !env; + let cl = List.map (subst_and_error !env) cl in + solve_one cl + in + let cl = simplify_constr cl in + let cl = solve_one cl in + cl, !env + +(* Typing of expressions *) +let rec type_exp env e = + try + let desc, ty = match e.e_desc with + | Econst (VBit _) -> e.e_desc, TBit + | Econst (VBitArray a) -> e.e_desc, TBitArray (mk_static_int (Array.length a)) + | Evar id -> Evar id, IdentEnv.find id env + | Ereg e -> + let e = expect_exp env e TBit in + Ereg e, TBit + | Emem (MRom, addr_size, word_size, file, args) -> + (* addr_size > 0 *) + add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size))); + let read_addr = assert_1 args in + let read_addr = expect_exp env read_addr (TBitArray addr_size) in + Emem (MRom, addr_size, word_size, file, [read_addr]), TBitArray word_size + | Emem (MRam, addr_size, word_size, file, args) -> + (* addr_size > 0 *) + add_constraint (mk_static_exp (SBinOp (SLess, mk_static_int 0, addr_size))); + let read_addr, write_en, write_addr, data_in = assert_4 args in + let read_addr = expect_exp env read_addr (TBitArray addr_size) in + let write_addr = expect_exp env write_addr (TBitArray addr_size) in + let data_in = expect_exp env data_in (TBitArray word_size) in + let write_en = expect_exp env write_en TBit in + let args = [read_addr; write_en; write_addr; data_in] in + Emem (MRam, addr_size, word_size, file, args), TBitArray word_size + | Ecall (f, params, args) -> + let s = Modules.find_node f params in + (*check arity*) + if List.length s.s_inputs <> List.length args then + error (Args_arity_error (List.length args, List.length s.s_inputs)); + (*check types of all arguments*) + let args = List.map2 (expect_exp env) args s.s_inputs in + List.iter add_constraint s.s_constraints; + Ecall(f, params, args), prod s.s_outputs + in + { e with e_desc = desc; e_ty = ty }, ty + with + | Typing_error k -> message e.e_loc k; raise Error + +and expect_exp env e ty = + let e, found_ty = type_exp env e in + try + unify ty found_ty; + e + with + Unify -> error (Type_error (found_ty, ty)) + +let type_pat env pat = match pat with + | Evarpat x -> IdentEnv.find x env + | Etuplepat id_list -> prod (List.map (fun x -> IdentEnv.find x env) id_list) + +let type_eq env (pat, e) = + let pat_ty = type_pat env pat in + let e = expect_exp env e pat_ty in + (pat, e) + +let build env vds = + let build_one env vd = IdentEnv.add vd.v_ident vd.v_ty env in + List.fold_left build_one env vds + +let rec type_block env b = match b with + | BEqs(eqs, vds) -> + let vds = List.map (fun vd -> { vd with v_ty = fresh_type () }) vds in + let env = build env vds in + let eqs = List.map (type_eq env) eqs in + BEqs(eqs,vds) + | BIf(se, trueb, falseb) -> + expect_static_exp se STBool; + let prev_constr = get_constraints () in + let trueb = type_block env trueb in + let true_constr = + List.map (fun c -> mk_static_exp (SIf (se, c, mk_static_bool true))) (get_constraints ()) + in + let falseb = type_block env falseb in + let false_constr = + List.map (fun c -> mk_static_exp (SIf (se, mk_static_bool true, c))) (get_constraints ()) + in + set_constraints (prev_constr @ true_constr @ false_constr); + BIf(se, trueb, falseb) + +let ty_repr_block env b = + let static_exp funs acc se = simplify env se, acc in + let ty funs acc ty = + let ty = ty_repr ty in + (* go through types to substitute static exps *) + Mapfold.ty funs acc ty + in + let funs = { Mapfold.defaults with ty = ty; static_exp = static_exp } in + let b, _ = Mapfold.block_it funs () b in + b + +let node n = + try + Modules.add_node n []; + let env = build IdentEnv.empty n.n_inputs in + let env = build env n.n_outputs in + let body = type_block env n.n_body in + let constr = get_constraints () in + let constr, env = solve_constr n.n_params constr in + let body = ty_repr_block env body in + Modules.add_node n constr; + { n with n_body = body; n_constraints = constr } + with + Typing_error k -> message n.n_loc k; raise Error + +let program p = + let p_nodes = List.map node p.p_nodes in + { p with p_nodes = p_nodes } |