|
| 1 | +from assassyn.frontend import * |
| 2 | +from assassyn.test import run_test |
| 3 | +import random |
| 4 | +# MulStage 3: final adder (adds the CSA tree's sum and carry words) |
class FinalAdder(Module):
    """Last multiplier stage: adds the sum and carry words and logs the result.

    Ports:
        a, b:  the original 32-bit operands, forwarded only for logging.
        cnt:   32-bit tag forwarded by the previous stage (not used here).
        s:     64-bit sum word produced by the CSA tree.
        carry: 64-bit carry word produced by the CSA tree.
    """

    def __init__(self):
        port_table = {
            'a': Port(Int(32)),
            'b': Port(Int(32)),
            'cnt': Port(Int(32)),
            's': Port(Int(64)),
            'carry': Port(Int(64)),
        }
        super().__init__(ports=port_table)

    @module.combinational
    def build(self):
        # `cnt` is popped together with the other ports but is not referenced below.
        a, b, cnt, s, carry = self.pop_all_ports(True)
        # The checker in this file parses this exact message layout.
        log("Final result {:?} * {:?} = {:?}", a, b, s + carry)
| 21 | + |
| 22 | + |
| 23 | +# MulStage 2: CSA + Pseudo-Wallace Tree |
class CSATree(Module):
    """Second multiplier stage: compresses 16 partial products down to two words.

    Ports:
        a, b, cnt:        32-bit values forwarded unchanged to the final adder.
        booth0..booth15:  the sixteen 64-bit partial products to be reduced.
    """

    def __init__(self):
        port_table = {
            'a': Port(Int(32)),
            'b': Port(Int(32)),
            'cnt': Port(Int(32)),
        }
        # booth0 .. booth15, declared in ascending order after a/b/cnt.
        for idx in range(16):
            port_table[f'booth{idx}'] = Port(Int(64))
        super().__init__(ports=port_table)

    @module.combinational
    def build(self, finaladder: FinalAdder):
        """Reduce the partial products with 3:2 compressors and call `finaladder`."""
        popped = list(self.pop_all_ports(True))
        a, b, cnt = popped[:3]
        # The remaining sixteen values are booth0..booth15, in port order.
        level = popped[3:]

        def csa(x1, x2, x3):
            """Carry-save add three 64-bit words; returns (sum word, carry word)."""
            u = x1.bitcast(Bits(64))
            v = x2.bitcast(Bits(64))
            w = x3.bitcast(Bits(64))

            # Bitwise sum without carries.
            parity = (u ^ v) ^ w
            # A carry is produced wherever at least two inputs are set; it has
            # the weight of the next bit position, hence the shift by one.
            majority = (u & v) | (v & w) | (w & u)
            carry_word = majority << Bits(64)(1)

            return parity.bitcast(Int(64)), carry_word.bitcast(Int(64))

        # Each pass turns every complete group of three words into two and
        # carries any leftover words over to the next pass untouched.
        while len(level) > 2:
            full_groups = len(level) // 3
            reduced = []
            for group in range(full_groups):
                first = 3 * group
                s, c = csa(level[first], level[first + 1], level[first + 2])
                reduced.append(s)
                reduced.append(c)
            for leftover in level[3 * full_groups:]:
                reduced.append(leftover)
            level = reduced

        final_s = level[0]
        final_carry = level[1]
        log("CSATree: sum = {:?}, carry = {:?}",final_s,final_carry)
        finaladder.async_called(a=a, b=b, cnt=cnt, s=final_s, carry=final_carry)
| 107 | + |
| 108 | +# MulStage 1: radix-4 booth encoding |
class BoothEncoder(Module):
    """First multiplier stage: radix-4 Booth recoding of `b` into 16 partial products.

    Ports:
        a:   32-bit multiplicand.
        b:   32-bit multiplier whose bits are recoded.
        cnt: 32-bit tag forwarded unchanged to the next stage.

    Partial product `i` is `(b[2i-1] + b[2i] - 2*b[2i+1]) * a`, shifted left by
    `2*i`, with `b[-1]` taken as 0.
    """

    def __init__(self):
        super().__init__(
            ports={
                'a': Port(Int(32)),
                'b': Port(Int(32)),
                'cnt': Port(Int(32)),
            }
        )

    @module.combinational
    def build(self, csatree: CSATree):
        """Emit the sixteen partial products and hand them to `csatree`."""
        a, b, cnt = self.pop_all_ports(True)

        b_unsigned = b.bitcast(Bits(32))

        # -2 * a: the contribution of a set high bit in a Booth digit.
        a_comp = ((Int(32)(-1) * a) << Int(32)(1)).bitcast(Int(32))

        def bit_of_b(pos):
            """Return bit `pos` of the multiplier as a 0/1 Int(32) value."""
            if pos == 0:
                return (b_unsigned & Bits(32)(1)).bitcast(Int(32))
            return ((b_unsigned >> Bits(32)(pos)) & Bits(32)(1)).bitcast(Int(32))

        # Every bit is read directly from `b_unsigned` at its absolute position.
        # The previous chunked extraction masked each 6-bit chunk with 31
        # (five bits) and then shifted the masked value right by 6 again, so
        # multiplier bits 11..31 never reached any partial product.
        # NOTE(review): the partial-product arithmetic itself is unchanged; its
        # intermediate widths depend on assassyn's result-type rules -- confirm
        # they are wide enough for full 32-bit operands.
        booths = {}
        below = Int(32)(0)  # b[-1], the implicit zero below the LSB
        for digit in range(16):
            low = bit_of_b(2 * digit)
            high = bit_of_b(2 * digit + 1)
            term = (below + low) * a + high * a_comp
            if digit:
                # Digit i carries weight 4**i.
                term = term << Int(32)(2 * digit)
            booths[f'booth{digit}'] = term.bitcast(Int(64))
            # The high bit of this digit is the overlap bit of the next one.
            below = high

        log("BoothEncoder: DONE booth coding for {:?} * {:?}", a, b)
        csatree.async_called(a=a, b=b, cnt=cnt, **booths)
| 198 | + |
class Driver(Module):
    """Test stimulus: feeds incrementing operand pairs into the Booth encoder."""

    def __init__(self):
        super().__init__(ports={})

    @module.combinational
    def build(self, boothencoder: BoothEncoder):
        """Create the counter and operand registers and issue the gated calls."""
        tick = RegArray(Int(32), 1)
        (tick & self)[0] <= tick[0] + Int(32)(1)
        # Calls are issued only while the counter is below 95.
        in_range = tick[0] < Int(32)(95)

        # Both operands step by one alongside the counter, so the pairs are
        # presumably (0, 0) .. (94, 94) -- assumes the registers start at 0.
        a_reg = RegArray(Int(32),1)
        b_reg = RegArray(Int(32),1)
        (a_reg & self)[0] <= a_reg[0] + Int(32)(1)
        (b_reg & self)[0] <= b_reg[0] + Int(32)(1)

        with Condition(in_range):
            boothencoder.async_called(a=a_reg[0], b=b_reg[0], cnt=tick[0])
| 215 | + |
def build_system():
    """Instantiate and wire the multiplier pipeline, last stage first.

    Each stage's `build` takes the stage it calls into, so the modules are
    created from the final adder back to the driver.
    """
    adder_stage = FinalAdder()
    adder_stage.build()

    tree_stage = CSATree()
    tree_stage.build(adder_stage)

    encoder_stage = BoothEncoder()
    encoder_stage.build(tree_stage)

    stimulus = Driver()
    stimulus.build(encoder_stage)
| 226 | + |
def check_raw(raw):
    """Check every "Final result" line of the simulator log.

    Each matching line must end in ``<a> * <b> = <c>``; the three numbers are
    read from the end of the line and ``a * b`` must equal ``c``.

    Args:
        raw: the complete log output as a single string.

    Raises:
        AssertionError: if a logged product is wrong, or if the log contains
            no "Final result" line at all (so an empty run cannot pass).
    """
    checked = 0
    for line in raw.split('\n'):
        if 'Final result' not in line:
            continue
        toks = line.split()
        # Tokens from the end: a, '*', b, '=', c.
        a, b, c = int(toks[-5]), int(toks[-3]), int(toks[-1])
        assert a * b == c, f'{a} * {b} != {c}'
        checked += 1
    # Previously the count was collected but never inspected.
    assert checked > 0, "no 'Final result' lines found in the log"
| 237 | + |
def test_multiplier():
    """Build the multiplier system, run it, and check its log with `check_raw`."""
    run_test(
        'multiplier_test',
        build_system,
        check_raw,
    )
| 240 | + |
# Allow running this file directly as a script, outside a test runner.
if __name__ == '__main__':
    test_multiplier()
0 commit comments