Skip to content

Commit d179939

Browse files
authored
Add a multiplier using CSA and booth coding (#418)
1 parent 00a21d5 commit d179939

File tree

1 file changed

+242
-0
lines changed

1 file changed

+242
-0
lines changed

python/ci-tests/test_csamul.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from assassyn.frontend import *
2+
from assassyn.test import run_test
3+
import random
4+
# Stage 3: full adder
5+
class FinalAdder(Module):
    """Last stage of the multiplier: add the two words left by the CSA tree.

    Ports:
        a, b:   32-bit operands, forwarded here only so they can be logged.
        cnt:    32-bit value forwarded from the driver; popped but not used.
        s:      64-bit sum word produced by the previous stage.
        carry:  64-bit carry word produced by the previous stage.
    """

    def __init__(self):
        port_table = {
            'a': Port(Int(32)),
            'b': Port(Int(32)),
            'cnt': Port(Int(32)),
            's': Port(Int(64)),
            'carry': Port(Int(64)),
        }
        super().__init__(ports=port_table)

    @module.combinational
    def build(self):
        """Pop all ports, add ``s`` and ``carry``, and log the result line."""
        # Values come back in the order the ports were declared above.
        a, b, cnt, s, carry = self.pop_all_ports(True)
        product = s + carry
        # The test checker parses this exact message; keep the format intact.
        log("Final result {:?} * {:?} = {:?}", a, b, product)
21+
22+
23+
# MulStage 2: CSA + Pseudo-Wallace Tree
24+
class CSATree(Module):
    """Second stage: compress 16 partial products down to a sum/carry pair.

    Ports:
        a, b, cnt:        32-bit values passed through untouched to the next stage.
        booth0..booth15:  the sixteen 64-bit partial products to be reduced.
    """

    def __init__(self):
        port_table = {
            'a': Port(Int(32)),
            'b': Port(Int(32)),
            'cnt': Port(Int(32)),
        }
        # booth0 .. booth15 follow the three scalar ports, in index order.
        for idx in range(16):
            port_table[f'booth{idx}'] = Port(Int(64))
        super().__init__(ports=port_table)

    @module.combinational
    def build(self, finaladder: FinalAdder):
        """Reduce the partial products with 3:2 compressors and call ``finaladder``.

        Args:
            finaladder: the module that receives the final sum and carry words.
        """
        a, b, cnt, booth0, booth1, booth2, booth3, booth4, booth5, booth6, booth7, \
            booth8, booth9, booth10, booth11, booth12, booth13, booth14, booth15 = \
            self.pop_all_ports(True)

        def csa(x1, x2, x3):
            """3:2 carry-save adder: return ``(sum, carry)`` for three 64-bit words.

            The sum word is the bitwise XOR of the inputs and the carry word is
            their bitwise majority shifted left by one position.
            """
            w1 = x1.bitcast(Bits(64))
            w2 = x2.bitcast(Bits(64))
            w3 = x3.bitcast(Bits(64))

            sum_bits = (w1 ^ w2) ^ w3

            both_12 = w1 & w2
            both_23 = w2 & w3
            either_pair = both_12 | both_23
            both_31 = w3 & w1
            majority = either_pair | both_31
            carry_bits = majority << Bits(64)(1)

            return sum_bits.bitcast(Int(64)), carry_bits.bitcast(Int(64))

        layer = [
            booth0, booth1, booth2, booth3,
            booth4, booth5, booth6, booth7,
            booth8, booth9, booth10, booth11,
            booth12, booth13, booth14, booth15,
        ]

        # Each pass turns every complete group of three words into two and
        # carries the 0-2 leftover words over unchanged: 16 -> 11 -> 8 -> 6 -> 4 -> 3 -> 2.
        while len(layer) > 2:
            reduced = []
            full_groups = len(layer) // 3
            for group in range(full_groups):
                base = 3 * group
                s, c = csa(layer[base], layer[base + 1], layer[base + 2])
                reduced.append(s)
                reduced.append(c)
            for leftover in layer[3 * full_groups:]:
                reduced.append(leftover)
            layer = reduced

        final_s = layer[0]
        final_carry = layer[1]
        log("CSATree: sum = {:?}, carry = {:?}", final_s, final_carry)
        finaladder.async_called(a=a, b=b, cnt=cnt, s=final_s, carry=final_carry)
107+
108+
# MulStage 1: radix-4 booth encoding
109+
class BoothEncoder(Module):
    """First stage: radix-4 Booth recoding of ``b`` into 16 partial products.

    Partial product ``i`` (0 <= i < 16) is built from the bit triplet
    ``(b[2i-1], b[2i], b[2i+1])`` as ``(b[2i-1] + b[2i]) * a + b[2i+1] * (-2a)``,
    shifted left by ``2i``; ``b[-1]`` is the constant 0.

    Ports:
        a:    32-bit multiplicand.
        b:    32-bit multiplier that gets recoded.
        cnt:  32-bit value passed through untouched to the next stage.
    """

    def __init__(self):
        super().__init__(
            ports={
                'a': Port(Int(32)),
                'b': Port(Int(32)),
                'cnt': Port(Int(32)),
            }
        )

    @module.combinational
    def build(self, csatree: CSATree):
        """Compute ``booth0``..``booth15`` and hand them to ``csatree``.

        Args:
            csatree: the module that receives the operands and partial products.
        """
        a, b, cnt = self.pop_all_ports(True)

        # Work on the raw bit pattern of b for the shift-and-mask extraction.
        b_unsigned = b.bitcast(Bits(32))

        # -2 * a, the weight of the most significant bit of each triplet.
        # NOTE(review): the value is cast back to Int(32); presumably it wraps
        # for |a| >= 2**30 -- confirm against the DSL's multiply/shift widths.
        a_comp = ((Int(32)(-1) * a) << Int(32)(1)).bitcast(Int(32))

        # Pull every bit of b directly out of b_unsigned.
        #
        # The earlier implementation processed b in 6-bit groups, masking each
        # group with 31 (five bits) and deriving the next group by shifting the
        # already-masked value right by 6.  Bit 11 therefore always read as 0
        # and all later groups were 0, so bits 11..31 of b never contributed
        # to the product.  Indexing from b_unsigned removes that truncation.
        b_bits = []
        for pos in range(32):
            if pos == 0:
                bit = (b_unsigned & Bits(32)(1)).bitcast(Int(32))
            else:
                bit = ((b_unsigned >> Bits(32)(pos)) & Bits(32)(1)).bitcast(Int(32))
            b_bits.append(bit)

        # Implicit bit to the right of the LSB, used only by the first triplet.
        below_lsb = Int(32)(0)

        partial_products = []
        for idx in range(16):
            if idx == 0:
                low = below_lsb
            else:
                low = b_bits[2 * idx - 1]
            mid = b_bits[2 * idx]
            high = b_bits[2 * idx + 1]

            weighted = (low + mid) * a + high * a_comp
            if idx == 0:
                # The first partial product has weight 1, so it is not shifted.
                pp = weighted.bitcast(Int(64))
            else:
                pp = (weighted << Int(32)(2 * idx)).bitcast(Int(64))
            partial_products.append(pp)

        # Keyword names must match the booth0..booth15 ports of CSATree.
        booth_args = {f'booth{idx}': pp for idx, pp in enumerate(partial_products)}

        log("BoothEncoder: DONE booth coding for {:?} * {:?}", a, b)
        csatree.async_called(a=a, b=b, cnt=cnt, **booth_args)
198+
199+
class Driver(Module):
    """Stimulus generator: feeds operand pairs into the Booth encoder.

    The module has no ports.  It owns three single-entry register arrays that
    are each incremented by one on every activation.
    """

    def __init__(self):
        super().__init__(ports={})

    @module.combinational
    def build(self, boothencoder: BoothEncoder):
        """Create the counters and call ``boothencoder`` while ``cnt`` is below 95.

        Args:
            boothencoder: the first multiplier stage, called with ``a``, ``b``, ``cnt``.
        """
        # Activation counter; it also gates how many test vectors are issued.
        cnt = RegArray(Int(32), 1)
        (cnt & self)[0] <= cnt[0] + Int(32)(1)
        still_feeding = cnt[0] < Int(32)(95)

        # Both operands step in lockstep, so every call multiplies a value by
        # itself.  The original comment states the inputs run from 0 to 94,
        # which presumes the registers start at zero.
        input_a = RegArray(Int(32), 1)
        input_b = RegArray(Int(32), 1)
        (input_a & self)[0] <= input_a[0] + Int(32)(1)
        (input_b & self)[0] <= input_b[0] + Int(32)(1)

        with Condition(still_feeding):
            boothencoder.async_called(a=input_a[0], b=input_b[0], cnt=cnt[0])
215+
216+
def build_system():
    """Instantiate the four modules and wire the pipeline back to front.

    Each stage's ``build`` takes the stage it calls into, so the modules are
    created from the last stage (the final adder) to the first (the driver).
    """
    adder_stage = FinalAdder()
    adder_stage.build()

    tree_stage = CSATree()
    tree_stage.build(adder_stage)

    encoder_stage = BoothEncoder()
    encoder_stage.build(tree_stage)

    stimulus = Driver()
    stimulus.build(encoder_stage)
226+
227+
def check_raw(raw):
    """Validate every ``Final result`` line of the simulator output.

    A result line is expected to end with ``<a> * <b> = <c>``; the last, third
    from last and fifth from last whitespace-separated tokens are parsed as
    integers and ``a * b == c`` is asserted.

    Args:
        raw: the complete simulator log as one string.

    Raises:
        AssertionError: if a logged product is wrong, or if the log contains no
            ``Final result`` line at all (otherwise the check passes vacuously).
    """
    checked = 0
    for line in raw.splitlines():
        if 'Final result' not in line:
            continue
        toks = line.split()
        lhs, rhs, product = toks[-5], toks[-3], toks[-1]
        assert int(lhs) * int(rhs) == int(product), \
            f"wrong product in log line: {line!r}"
        checked += 1

    # Without this, a run that never reaches the final stage would still pass.
    assert checked > 0, "no 'Final result' lines found in simulator output"
237+
238+
def test_multiplier():
    """Test entry point: build the multiplier, run it, and check its log.

    Passes the system builder and the log checker to ``run_test`` under the
    name ``multiplier_test``.
    """
    run_test('multiplier_test', build_system, check_raw)
240+
241+
# Allow running this file directly as a script, without a test runner.
if __name__ == '__main__':
    test_multiplier()

0 commit comments

Comments
 (0)