From bd04128b033c8a623cceca31de072053837ad888 Mon Sep 17 00:00:00 2001 From: Alex AUVOLAT Date: Fri, 8 Nov 2013 22:43:54 +0100 Subject: More optimizations. --- sched/netlist_ast.ml | 4 +- sched/netlist_dumb.ml | 105 +++------------------------------------------ sched/netlist_parser.mly | 6 +-- sched/netlist_printer.ml | 10 ++--- sched/scheduler.ml | 19 +++++++++ sched/simplify.ml | 109 ++++++++++++++++++++++++++++++----------------- tests/nadder.mj | 6 ++- 7 files changed, 107 insertions(+), 152 deletions(-) diff --git a/sched/netlist_ast.ml b/sched/netlist_ast.ml index ae16888..1866ed3 100644 --- a/sched/netlist_ast.ml +++ b/sched/netlist_ast.ml @@ -10,8 +10,8 @@ module Env = struct List.fold_left (fun env (x, ty) -> add x ty env) empty l end -type ty = TBit | TBitArray of int -type value = VBit of bool | VBitArray of bool array +type ty = int (* just one for a bit... *) +type value = bool array type binop = Or | Xor | And | Nand diff --git a/sched/netlist_dumb.ml b/sched/netlist_dumb.ml index 1e2a57f..6b8f526 100644 --- a/sched/netlist_dumb.ml +++ b/sched/netlist_dumb.ml @@ -49,9 +49,8 @@ let mkbinstr a = done; r -let const_info = function - | VBit(a) -> "$" ^ (mkbinstr [|a|]), 1, [|a|] - | VBitArray(a) -> "$" ^ (mkbinstr a), Array.length a, a +let const_info a = + "$" ^ (mkbinstr a), Array.length a, a let make_program_dumb p = (* @@ -96,13 +95,9 @@ let make_program_dumb p = (* Make ids for variables *) Env.iter (fun k v -> - let sz = match v with - | TBit -> 1 - | TBitArray(n) -> n - in - vars := { name = k; size = sz }::(!vars); - Hashtbl.add var_map k (!next_id); - next_id := !next_id + 1) + vars := { name = k; size = v }::(!vars); + Hashtbl.add var_map k (!next_id); + next_id := !next_id + 1) p.p_vars; let var_id = Hashtbl.find var_map in @@ -229,93 +224,3 @@ let print_dumb_program oc p = let print_program oc p = print_dumb_program oc (make_program_dumb p) - -(* OLD PRINTER CODE *) - - -(* - -(* constants *) -let c_arg = 0 -let c_reg = 1 -let c_not = 2 -let c_binop = 3 -let c_mux = 4 -let c_rom = 5 -let c_ram = 6 -let c_concat = 7 -let c_slice = 8 -let c_select = 9 - -let print_program oc p = - let ff = formatter_of_out_channel oc in - (* associate numbers to variables *) - let n_vars = Env.fold (fun _ _ n -> n+1) p.p_vars 0 in - let n = ref 0 in - let var_id = Hashtbl.create n_vars in - fprintf ff "%d\n" n_vars; - Env.iter - (fun k v -> - Hashtbl.add var_id k !n; - fprintf ff "%d %s\n" - (match v with - | TBit -> 1 - | TBitArray(n) -> n) - k; - n := !n + 1) - p.p_vars; - (* write input vars *) - fprintf ff "%d" (List.length p.p_inputs); - List.iter (fun k -> fprintf ff " %d" (Hashtbl.find var_id k)) p.p_inputs; - fprintf ff "\n"; - (* write output vars *) - fprintf ff "%d" (List.length p.p_outputs); - List.iter (fun k -> fprintf ff " %d" (Hashtbl.find var_id k)) p.p_outputs; - fprintf ff "\n"; - (* write equations *) - fprintf ff "%d\n" (List.length p.p_eqs); - (* write equations *) - let print_arg = function - | Avar(k) -> fprintf ff " $%d" (Hashtbl.find var_id k) - | Aconst(n) -> fprintf ff " "; - begin match n with - | VBit(x) -> fprintf ff "%d" (if x then 1 else 0) - | VBitArray(a) -> - for i = 0 to Array.length a - 1 do - fprintf ff "%d" (if a.(i) then 1 else 0) - done - end - in - List.iter - (fun (k, eqn) -> - fprintf ff "%d " (Hashtbl.find var_id k); - begin match eqn with - | Earg(a) -> fprintf ff "%d" c_arg; - print_arg a - | Ereg(i) -> fprintf ff "%d %d" c_reg (Hashtbl.find var_id i) - | Enot(a) -> fprintf ff "%d" c_not; - print_arg a - | Ebinop(o, a, b) -> fprintf ff "%d %d" c_binop (binop_i o); - print_arg a; - print_arg b - | Emux(a, b, c) -> fprintf ff "%d" c_mux; - print_arg a; print_arg b; print_arg c - | Erom(u, v, a) -> fprintf ff "%d %d %d" c_rom u v; - print_arg a - | Eram (u, v, a, b, c, d) -> fprintf ff "%d %d %d" c_ram u v; - print_arg a; print_arg b; print_arg c; print_arg d - | Econcat(a, b) -> fprintf ff "%d" c_concat; - print_arg a; print_arg b - | Eslice(u, v, a) -> fprintf ff "%d %d %d" c_slice u v; - print_arg a - | Eselect(i, a) -> fprintf ff "%d %d" c_select i; - print_arg a - end; - fprintf ff "\n") - p.p_eqs; - (* flush *) - fprintf ff "@." - -*) - - diff --git a/sched/netlist_parser.mly b/sched/netlist_parser.mly index 1f76528..66b4eab 100644 --- a/sched/netlist_parser.mly +++ b/sched/netlist_parser.mly @@ -6,7 +6,7 @@ for i = 0 to String.length n - 1 do if n.[i] = '1' then ret.(i) <- true done; - VBitArray(ret) + ret %} @@ -56,5 +56,5 @@ arg: var: x=NAME ty=ty_exp { (x, ty) } ty_exp: - | /*empty*/ { TBit } - | COLON n=INT { TBitArray (int_of_string n) } + | /*empty*/ { 1 } + | COLON n=INT { int_of_string n } diff --git a/sched/netlist_printer.ml b/sched/netlist_printer.ml index 547a0be..2c80d70 100644 --- a/sched/netlist_printer.ml +++ b/sched/netlist_printer.ml @@ -19,9 +19,8 @@ let rec print_list print lp sep rp ff = function List.iter (fprintf ff "%s %a" sep print) l; fprintf ff "%s" rp -let print_ty ff ty = match ty with - | TBit -> () - | TBitArray n -> fprintf ff " : %d" n +let print_ty ff n = + fprintf ff " : %d" n let print_bool ff b = if b then @@ -29,9 +28,8 @@ let print_bool ff b = else fprintf ff "0" -let print_value ff v = match v with - | VBit b -> print_bool ff b - | VBitArray a -> Array.iter (print_bool ff) a +let print_value ff a = + Array.iter (print_bool ff) a let print_arg ff arg = match arg with | Aconst v -> print_value ff v diff --git a/sched/scheduler.ml b/sched/scheduler.ml index 34ce3aa..d079f64 100644 --- a/sched/scheduler.ml +++ b/sched/scheduler.ml @@ -23,6 +23,25 @@ let read_exp eq = in aux eq +let read_exp_all eq = + let add_arg x l = match x with + | Avar(f) -> f::l + | Aconst(_) -> l + in + let aux = function + | Earg(x) -> add_arg x [] + | Ereg(i) -> [i] + | Enot(x) -> add_arg x [] + | Ebinop(_, x, y) -> add_arg x (add_arg y []) + | Emux(a, b, c) -> add_arg a (add_arg b (add_arg c [])) + | Erom(_, _, a) -> add_arg a [] + | Eram(_, _, a, b, c, d) -> add_arg a (add_arg b (add_arg c (add_arg d []))) + | Econcat(u, v) -> add_arg u (add_arg v []) + | Eslice(_, _, a) -> add_arg a [] + | Eselect(_, a) -> add_arg a [] + in + aux eq + let prog_eq_map p = List.fold_left (fun x (vn, eqn) -> Smap.add vn eqn x) diff --git a/sched/simplify.ml b/sched/simplify.ml index 4f2359e..db8125b 100644 --- a/sched/simplify.ml +++ b/sched/simplify.ml @@ -69,37 +69,29 @@ let arith_simplify p = (fun (n, eq) -> let useless = ref false in let neq = match eq with - | Ebinop(Or, Aconst(VBit(false)), x) -> Earg(x) - | Ebinop(Or, Aconst(VBit(true)), x) -> Earg(Aconst(VBit(true))) - | Ebinop(Or, x, Aconst(VBit(false))) -> Earg(x) - | Ebinop(Or, x, Aconst(VBit(true))) -> Earg(Aconst(VBit(true))) + | Ebinop(Or, Aconst([|false|]), x) -> Earg(x) + | Ebinop(Or, Aconst([|true|]), x) -> Earg(Aconst([|true|])) + | Ebinop(Or, x, Aconst([|false|])) -> Earg(x) + | Ebinop(Or, x, Aconst([|true|])) -> Earg(Aconst([|true|])) - | Ebinop(And, Aconst(VBit(false)), x) -> Earg(Aconst(VBit(false))) - | Ebinop(And, Aconst(VBit(true)), x) -> Earg(x) - | Ebinop(And, x, Aconst(VBit(false))) -> Earg(Aconst(VBit(false))) - | Ebinop(And, x, Aconst(VBit(true))) -> Earg(x) + | Ebinop(And, Aconst([|false|]), x) -> Earg(Aconst([|false|])) + | Ebinop(And, Aconst([|true|]), x) -> Earg(x) + | Ebinop(And, x, Aconst([|false|])) -> Earg(Aconst([|false|])) + | Ebinop(And, x, Aconst([|true|])) -> Earg(x) - | Ebinop(Xor, Aconst(VBit(false)), x) -> Earg(x) - | Ebinop(Xor, x, Aconst(VBit(false))) -> Earg(x) + | Ebinop(Xor, Aconst([|false|]), x) -> Earg(x) + | Ebinop(Xor, x, Aconst([|false|])) -> Earg(x) | Eslice(i, j, k) when i = j -> Eselect(i, k) | Econcat(Aconst(a), Aconst(b)) -> - let aa = match a with - | VBit(a) -> [| a |] - | VBitArray(a) -> a - in - let ba = match b with - | VBit(a) -> [| a |] - | VBitArray(a) -> a - in - Earg(Aconst(VBitArray(Array.append aa ba))) + Earg(Aconst(Array.append a b)) - | Eslice(i, j, Aconst(VBitArray(a))) -> - Earg(Aconst(VBitArray(Array.sub a i (j - i + 1)))) + | Eslice(i, j, Aconst(a)) -> + Earg(Aconst(Array.sub a i (j - i + 1))) - | Eselect(i, Aconst(VBitArray(a))) -> - Earg(Aconst(VBit(a.(i)))) + | Eselect(i, Aconst(a)) -> + Earg(Aconst([|a.(i)|])) | _ -> useless := true; eq in if not !useless then usefull := true; @@ -112,16 +104,19 @@ let arith_simplify p = (* if x is one bit, then : select 0 x = x + and same thing with select *) let select_to_id p = let usefull = ref false in { p_eqs = List.map (fun (n, eq) -> match eq with - | Eselect(0, Avar(id)) when - Env.find id p.p_vars = TBit || Env.find id p.p_vars = TBitArray(1) -> - usefull := true; - (n, Earg(Avar(id))) + | Eselect(0, Avar(id)) when Env.find id p.p_vars = 1 -> + usefull := true; + (n, Earg(Avar(id))) + | Eslice(0, sz, Avar(id)) when Env.find id p.p_vars = sz + 1 -> + usefull := true; + (n, Earg(Avar(id))) | _ -> (n, eq)) p.p_eqs; p_inputs = p.p_inputs; @@ -225,7 +220,36 @@ let rec eliminate_id p = (* Eliminate dead variables *) let eliminate_dead p = - (p, false) + let rec living basis = + let new_basis = List.fold_left + (fun b2 (n, eq) -> + if Sset.mem n b2 then + List.fold_left + (fun x k -> Sset.add k x) + b2 + (Scheduler.read_exp_all eq) + else + b2) + basis (List.rev p.p_eqs) + in + if Sset.cardinal new_basis > Sset.cardinal basis + then living new_basis + else new_basis + in + let outs = List.fold_left (fun x k -> Sset.add k x) Sset.empty p.p_outputs in + let ins = List.fold_left (fun x k -> Sset.add k x) Sset.empty p.p_inputs in + let live = living (Sset.union outs ins) in + { + p_eqs = List.filter (fun (n, _) -> Sset.mem n live) p.p_eqs; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = Env.fold + (fun k s newenv -> + if Sset.mem k live + then Env.add k s newenv + else newenv) + p.p_vars Env.empty + }, (Sset.cardinal live < Env.cardinal p.p_vars) (* Topological sort *) let topo_sort p = @@ -235,21 +259,28 @@ let topo_sort p = (* Apply all the simplification passes, in the order given in the header of this file *) -let rec simplify p = - let steps = [ - topo_sort, "topo_sort"; - cascade_slices, "cascade_slices"; - arith_simplify, "arith_simplify"; - select_to_id, "select_to_id"; - same_eq_simplify, "same_eq_simplify"; - eliminate_id, "eliminate_id"; - ] in +let rec simplify_with steps p = let pp, use = List.fold_left (fun (x, u) (f, n) -> print_string n; let xx, uu = f x in - print_string (if uu then "*\n" else "\n"); + print_string (if uu then " *\n" else "\n"); (xx, u || uu)) (p, false) steps in - if use then simplify pp else pp + if use then simplify_with steps pp else pp + +let simplify p = + let p = simplify_with [ topo_sort, "topo_sort" ] p in + let p = simplify_with [ + cascade_slices, "cascade_slices"; + arith_simplify, "arith_simplify"; + select_to_id, "select_to_id"; + same_eq_simplify, "same_eq_simplify"; + eliminate_id, "eliminate_id"; + ] p in + let p = simplify_with [ + eliminate_dead, "eliminate_dead"; + topo_sort, "topo_sort"; (* make sure last step is a topological sort *) + ] p in + p diff --git a/tests/nadder.mj b/tests/nadder.mj index c8b0fbe..91caa6d 100644 --- a/tests/nadder.mj +++ b/tests/nadder.mj @@ -1,3 +1,5 @@ +const word_size = 8 + fulladder(a,b,c) = (s, r) where s = (a ^ b) ^ c; r = (a & b) + ((a ^ b) & c); @@ -14,6 +16,6 @@ adder(a:[n], b:[n], c_in) = (o:[n], c_out) where end if end where -main(a:[8], b:[8]) = (o:[8], c) where - (o, c) = adder<8>(a,b,0) +main(a:[word_size], b:[word_size]) = (o:[word_size], c) where + (o, c) = adder(a,b,0) end where -- cgit v1.2.3