From 07b7563e0748b1aff6f4d28b0172095b2fdcdfcc Mon Sep 17 00:00:00 2001 From: Alex AUVOLAT Date: Tue, 5 Nov 2013 13:47:12 +0100 Subject: Added netlist simplification passes (not yet quite complete !) --- README | 16 ++++ csim/load.c | 6 +- csim/sim.c | 16 ++-- sched/main.ml | 2 +- sched/netlist_printer.ml | 4 +- sched/simplify.ml | 214 +++++++++++++++++++++++++++++++++++++++++++++++ tests/clockHMS.mj | 69 +++++++++++++++ tests/nadder.mj | 12 +-- 8 files changed, 322 insertions(+), 17 deletions(-) create mode 100644 sched/simplify.ml create mode 100644 tests/clockHMS.mj diff --git a/README b/README index ff55a2c..40734fa 100644 --- a/README +++ b/README @@ -6,6 +6,7 @@ Alex AUVOLAT (Info 2013) Contents of the repository : +---------------------------- sched/ A scheduler for netlists. @@ -36,3 +37,18 @@ minijazz/ Documentation about the project. + +CONVENTION FOR BINARY VALUES +---------------------------- + +/!\ This convention is contrary to the one used in the example file nadder.mj + (Therefore I have modified that file...) + +The bit array [a_0 a_1 a_2 ... a_n-1] represents the decimal number : + a_0 + 2*a_1 + 4*a_2 + ... + 2^(n-1)*a_n-1 + +When represented in binary, we write the bits in the order : + a_0 a_1 a_2 ... a_n-1 + +/!\ BINARY NUMBERS ARE WRITTEN REVERSE ! + diff --git a/csim/load.c b/csim/load.c index a7e6cac..7971a71 100644 --- a/csim/load.c +++ b/csim/load.c @@ -35,10 +35,10 @@ t_value read_bool(FILE *stream, t_value *mask) { void read_arg(FILE *stream, t_arg *dest) { dest->mask = 0; - if (fscanf(stream, "$") > 0) { - dest->Val = read_bool(stream, &dest->mask); + if (fscanf(stream, "$%d ", &(dest->SrcVar))) { + // ok, value is read } else { - fscanf(stream, "%d ", &(dest->SrcVar)); + dest->Val = read_bool(stream, &dest->mask); } } diff --git a/csim/sim.c b/csim/sim.c index 9b6906f..db3b711 100644 --- a/csim/sim.c +++ b/csim/sim.c @@ -55,7 +55,8 @@ t_machine *init_machine (t_program *p) { void read_inputs(t_machine *m, FILE *stream) { /* FORMAT : For each input in the list, *in the order specified*, - the binary value for that variable. + either '/' followed by the decimal value + or the binary value */ int i; t_id var; @@ -66,7 +67,12 @@ void read_inputs(t_machine *m, FILE *stream) { for (i = 0; i < p->n_inputs; i++) { var = p->inputs[i]; fscanf(stream, " "); - m->var_values[var] = read_bool(stream, NULL); + if (fscanf(stream, "/%lu", &(m->var_values[var]))) { + // ok, value is read + } else { + m->var_values[var] = read_bool(stream, NULL); + } + m->var_values[var] &= p->vars[var].mask; } } @@ -179,7 +185,7 @@ void machine_step(t_machine *m) { if (e != 0) { a = get_var(m, p->eqs[i].Ram.write_addr); d = get_var(m, p->eqs[i].Ram.data); - printf("Write ram %lx = %lx\n", a, d); + if (DEBUG) fprintf(stderr, "Write ram %lx = %lx\n", a, d); m->mem_data[i].RamData[a] = d; } } @@ -189,7 +195,7 @@ void machine_step(t_machine *m) { void write_outputs(t_machine *m, FILE *stream) { /* FORMAT : For each output value, a line in the form - var_name binary_value + var_name binary_value decimal_value */ int i; t_id var; @@ -205,7 +211,7 @@ void write_outputs(t_machine *m, FILE *stream) { v >>= 1; mask >>= 1; } - fprintf(stream, "\n"); + fprintf(stream, "\t%ld\n", m->var_values[var]); } fprintf(stream, "\n"); } diff --git a/sched/main.ml b/sched/main.ml index 988d1ec..a2a4d3b 100644 --- a/sched/main.ml +++ b/sched/main.ml @@ -11,7 +11,7 @@ let compile filename = let q = ref p in begin try - q := Scheduler.schedule p + q := (Simplify.simplify (Scheduler.schedule p)) with | Scheduler.Combinational_cycle -> Format.eprintf "The netlist has a combinatory cycle.@."; diff --git a/sched/netlist_printer.ml b/sched/netlist_printer.ml index b8cf385..746867f 100644 --- a/sched/netlist_printer.ml +++ b/sched/netlist_printer.ml @@ -133,8 +133,8 @@ let print_dumb_program oc p = fprintf ff "%d\n" (List.length p.p_eqs); (* write equations *) let print_arg = function - | Avar(k) -> fprintf ff " %d" (Hashtbl.find var_id k) - | Aconst(n) -> fprintf ff " $"; + | Avar(k) -> fprintf ff " $%d" (Hashtbl.find var_id k) + | Aconst(n) -> fprintf ff " "; begin match n with | VBit(x) -> fprintf ff "%d" (if x then 1 else 0) | VBitArray(a) -> diff --git a/sched/simplify.ml b/sched/simplify.ml new file mode 100644 index 0000000..01f7c84 --- /dev/null +++ b/sched/simplify.ml @@ -0,0 +1,214 @@ +(* SIMPLIFICATION PASSES *) + +(* + Order of simplifications : + - cascade slices and selects + - simplify stupid things (a xor 0 = a, a and 0 = 0, etc.) + transform k = SLICE i i var into k = SELECT i var + - transform k = SELECT 0 var into k = var when var is also one bit + - look for variables with same equation, put the second to identity + - eliminate k' for each equation k' = k + - eliminate dead equations + + These simplifications are run on a topologically sorted list of equations (see main.ml) +*) + +open Netlist_ast + +module Sset = Set.Make(String) + +(* Simplify cascade slicing/selecting *) +let cascade_slices p = + let slices = Hashtbl.create 42 in + let eqs_new = List.map + (fun (n, eq) -> (n, match eq with + | Eslice(u, v, Avar(x)) -> + let nu, nx = + if Hashtbl.mem slices x then begin + let ku, kx = Hashtbl.find slices x in + (ku + u, kx) + end else + (u, x) + in + Hashtbl.add slices n (nu, nx); + Eslice(nu, v, Avar(nx)) + | Eselect(u, Avar(x)) -> + begin try + let ku, kx = Hashtbl.find slices x in + Eselect(ku + u, Avar(kx)) + with + Not_found -> Eselect(u, Avar(x)) + end + | _ -> eq)) + p.p_eqs in + { + p_eqs = eqs_new; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = p.p_vars; + } + +(* Simplifies some trivial arithmetic possibilites : + a and 1 = a + a and 0 = 0 + a or 1 = 1 + a or 0 = a + a xor 0 = a + slice i i x = select i x +*) +let arith_simplify p = + { + p_eqs = List.map + (fun (n, eq) -> match eq with + | Ebinop(Or, Aconst(VBit(false)), x) -> (n, Earg(x)) + | Ebinop(Or, Aconst(VBit(true)), x) -> (n, Earg(Aconst(VBit(true)))) + | Ebinop(Or, x, Aconst(VBit(false))) -> (n, Earg(x)) + | Ebinop(Or, x, Aconst(VBit(true))) -> (n, Earg(Aconst(VBit(true)))) + + | Ebinop(And, Aconst(VBit(false)), x) -> (n, Earg(Aconst(VBit(false)))) + | Ebinop(And, Aconst(VBit(true)), x) -> (n, Earg(x)) + | Ebinop(And, x, Aconst(VBit(false))) -> (n, Earg(Aconst(VBit(false)))) + | Ebinop(And, x, Aconst(VBit(true))) -> (n, Earg(x)) + + | Ebinop(Xor, Aconst(VBit(false)), x) -> (n, Earg(x)) + | Ebinop(Xor, x, Aconst(VBit(false))) -> (n, Earg(x)) + + | Eslice(i, j, k) when i = j -> + (n, Eselect(i, k)) + + | _ -> (n, eq)) + p.p_eqs; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = p.p_vars; + } + +(* if x is one bit, then : + select 0 x = x +*) +let select_to_id p = + { + p_eqs = List.map + (fun (n, eq) -> match eq with + | Eselect(0, Avar(id)) when + Env.find id p.p_vars = TBit || Env.find id p.p_vars = TBitArray(1) -> + (n, Earg(Avar(id))) + | _ -> (n, eq)) + p.p_eqs; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = p.p_vars; + } + +(* + If a = eqn(v1, v2, ...) and b = eqn(v1, v2, ...) <- the same equation + then say b = a +*) +let same_eq_simplify p = + let id_outputs = + (List.fold_left (fun x k -> Sset.add k x) Sset.empty p.p_outputs) in + let eq_map = Hashtbl.create 42 in + List.iter + (fun (n, eq) -> if Sset.mem n id_outputs then + Hashtbl.add eq_map eq n) + p.p_eqs; + let simplify_eq (n, eq) = + if Sset.mem n id_outputs then + (n, eq) + else if Hashtbl.mem eq_map eq then + (n, Earg(Avar(Hashtbl.find eq_map eq))) + else begin + Hashtbl.add eq_map eq n; + (n, eq) + end + in + let eq2 = List.map simplify_eq p.p_eqs in + { + p_eqs = eq2; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = p.p_vars; + } + + +(* Replace one specific variable by another argument in the arguments of all equations + (possibly a constant, possibly another variable) +*) +let eliminate_var var rep p = + let rep_arg = function + | Avar(i) when i = var -> rep + | k -> k + in + let rep_eqs = List.map + (fun (n, eq) -> (n, match eq with + | Earg(a) -> Earg(rep_arg a) + | Ereg(i) when i = var -> + begin match rep with + | Avar(j) -> Ereg(j) + | Aconst(k) -> Earg(Aconst(k)) + end + | Ereg(j) -> Ereg(j) + | Enot(a) -> Enot(rep_arg a) + | Ebinop(o, a, b) -> Ebinop(o, rep_arg a, rep_arg b) + | Emux(a, b, c) -> Emux(rep_arg a, rep_arg b, rep_arg c) + | Erom(u, v, a) -> Erom(u, v, rep_arg a) + | Eram(u, v, a, b, c, d) -> Eram(u, v, rep_arg a, rep_arg b, rep_arg c, rep_arg d) + | Econcat(a, b) -> Econcat(rep_arg a, rep_arg b) + | Eslice(u, v, a) -> Eslice(u, v, rep_arg a) + | Eselect(u, a) -> Eselect(u, rep_arg a) + )) + p.p_eqs in + { + p_eqs = List.fold_left + (fun x (n, eq) -> + if n = var then x else (n, eq)::x) + [] rep_eqs; + p_inputs = p.p_inputs; + p_outputs = p.p_outputs; + p_vars = Env.remove var p.p_vars; + } + +(* Remove all equations of type : + a = b + a = const + (except if a is an output variable) +*) +let rec eliminate_id p = + let id_outputs = + (List.fold_left (fun x k -> Sset.add k x) Sset.empty p.p_outputs) in + + let rep = + List.fold_left + (fun x (n, eq) -> + if x = None && (not (Sset.mem n id_outputs)) then + match eq with + | Earg(rarg) -> + Some(n, rarg) + | _ -> None + else + x) + None p.p_eqs in + match rep with + | None -> p, false + | Some(n, rep) -> fst (eliminate_id (eliminate_var n rep p)), true + + +(* Eliminate dead equations *) +let eliminate_dead p = + p, false (* TODO *) + (* a bit like a topological sort... *) + + +(* Apply all the simplification passes, + in the order given in the header of this file +*) +let rec simplify p = + let p1 = cascade_slices p in + let p2 = arith_simplify p1 in + let p3 = select_to_id p2 in + let p4 = same_eq_simplify p3 in + let p5, use5 = eliminate_id p4 in + let p6, use6 = eliminate_dead p5 in + let pp = p6 in + if use5 || use6 then simplify pp else pp + diff --git a/tests/clockHMS.mj b/tests/clockHMS.mj new file mode 100644 index 0000000..27f96ea --- /dev/null +++ b/tests/clockHMS.mj @@ -0,0 +1,69 @@ +repeat(a) = (x:[n]) where + if n = 1 then + x = a + else + if n - (2 * (n / 2)) = 1 then + u = repeat(a); + x = a . u . u + else + u = repeat(a); + x = u . u + end if + end if +end where + +fulladder(a,b,c) = (s, r) where + s = (a ^ b) ^ c; + r = (a & b) + ((a ^ b) & c); +end where + +adder(a:[n], b:[n], c_in) = (o:[n], c_out) where + if n = 0 then + o = []; + c_out = 0 + else + (s_n, c_n1) = fulladder(a[0], b[0], c_in); + (s_n1, c_out) = adder(a[1..], b[1..], c_n1); + o = s_n . s_n1 + end if +end where + +equal(a:[n]) = (eq) where + if n = 0 then + eq = 1 + else + if m - (2 * (m / 2)) = 1 then + eq = a[0] & equal(a[1..]); + else + eq = (not a[0]) & equal(a[1..]); + end if + end if +end where + +reg_n(a:[n]) = (r:[n]) where + if n = 0 then + r = [] + else + r = (reg a[0]) . (reg_n(r[1..])) + end if +end where + +and_each(a, b:[n]) = (o:[n]) where + if n = 0 then + o = [] + else + o = (b[0] and a) . and_each(a, b[1..]) + end if +end where + +count_mod(in:[n]) = (out:[n]) where + neq = not (equal(in)); + (incr, carry) = adder(in, 1 . repeat(0), 0); + out = and_each(neq, incr) +end where + +main() = (ret:[2],out:[2]) where + out = count_mod<2, 3>(ret); + ret = reg_n<2>(out) +end where + diff --git a/tests/nadder.mj b/tests/nadder.mj index 0c95386..b75a83c 100644 --- a/tests/nadder.mj +++ b/tests/nadder.mj @@ -6,14 +6,14 @@ end where adder(a:[n], b:[n], c_in) = (o:[n], c_out) where if n = 0 then o = []; - c_out = 0 + c_out = c_in else - (s_n1, c_n1) = adder(a[1..], b[1..], c_in); - (s_n, c_out) = fulladder(a[0], b[0], c_n1); + (s_n, c_n1) = fulladder(a[0], b[0], c_in); + (s_n1, c_out) = adder(a[1..], b[1..], c_n1); o = s_n . s_n1 end if end where -main(a, b) = (o, c) where - (o, c) = adder<1>(a,b,0) -end where \ No newline at end of file +main(a:[4], b:[4]) = (o:[4], c) where + (o, c) = adder<4>(a,b,0) +end where -- cgit v1.2.3