Skip to content

Commit f67d4d7

Browse files
committed
bpu: add BTBMGSC unit test and refactor initialization
- Added conditional compilation for unit testing, allowing for a separate constructor and initialization logic.
- Introduced debug flags and test-specific structures to facilitate testing.
- Refactored the initStorage method to improve clarity and organization of storage initialization.
- Created a new test file for BTBMGSC to validate its functionality with unit tests.
- Updated SConscript to include the new test file in the build process.

Change-Id: Ib1ab026117ef87ebbf7097a8d32710843d32ed5b
1 parent 70d4415 commit f67d4d7

File tree

5 files changed

+298
-130
lines changed

5 files changed

+298
-130
lines changed

src/cpu/pred/btb/btb_mgsc.cc

Lines changed: 154 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
#include "cpu/pred/btb/btb_mgsc.hh"
22

3-
#ifndef UNIT_TEST
3+
#include "base/intmath.hh"
4+
5+
#ifdef UNIT_TEST
6+
#include "cpu/pred/btb/test/test_dprintf.hh"
7+
8+
// Define debug flags for unit testing
9+
namespace gem5 {
10+
namespace debug {
11+
bool MGSC = true;
12+
}
13+
}
14+
#else
415
#include "cpu/o3/dyn_inst.hh"
16+
#include "debug/MGSC.hh"
517

618
#endif
719

@@ -13,8 +25,6 @@
1325
#include <type_traits>
1426
#include <vector>
1527

16-
#include "debug/MGSC.hh"
17-
1828
namespace gem5
1929
{
2030

@@ -24,6 +34,133 @@ namespace branch_prediction
2434
namespace btb_pred
2535
{
2636

37+
#ifdef UNIT_TEST
38+
namespace test
39+
{
40+
#endif
41+
42+
void
43+
BTBMGSC::initStorage()
44+
{
45+
auto pow2 = [](unsigned width) -> uint64_t {
46+
assert(width < 63);
47+
return 1ULL << width;
48+
};
49+
auto allocPredTable = [&](std::vector<std::vector<std::vector<int16_t>>> &table, unsigned numTables,
50+
unsigned idxWidth) -> uint64_t {
51+
table.resize(numTables);
52+
auto tableSize = pow2(idxWidth);
53+
assert(tableSize > numCtrsPerLine);
54+
for (unsigned int i = 0; i < numTables; ++i) {
55+
table[i].resize(tableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
56+
}
57+
return tableSize;
58+
};
59+
60+
assert(isPowerOf2(numCtrsPerLine));
61+
numCtrsPerLineBits = log2i(numCtrsPerLine);
62+
63+
auto bwTableSize = allocPredTable(bwTable, bwTableNum, bwTableIdxWidth);
64+
for (unsigned int i = 0; i < bwTableNum; ++i) {
65+
indexBwFoldedHist.push_back(GlobalBwFoldedHist(bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16));
66+
}
67+
bwIndex.resize(bwTableNum);
68+
69+
auto lTableSize = allocPredTable(lTable, lTableNum, lTableIdxWidth);
70+
indexLFoldedHist.resize(numEntriesFirstLocalHistories);
71+
for (unsigned int i = 0; i < lTableNum; ++i) {
72+
for (unsigned int k = 0; k < numEntriesFirstLocalHistories; ++k) {
73+
indexLFoldedHist[k].push_back(LocalFoldedHist(lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16));
74+
}
75+
}
76+
lIndex.resize(lTableNum);
77+
78+
auto iTableSize = allocPredTable(iTable, iTableNum, iTableIdxWidth);
79+
for (unsigned int i = 0; i < iTableNum; ++i) {
80+
assert(iHistLen[i] >= 0);
81+
assert(static_cast<unsigned>(iHistLen[i]) < 63);
82+
assert(pow2(static_cast<unsigned>(iHistLen[i])) <= iTableSize);
83+
indexIFoldedHist.push_back(ImliFoldedHist(iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16));
84+
}
85+
iIndex.resize(iTableNum);
86+
87+
auto gTableSize = allocPredTable(gTable, gTableNum, gTableIdxWidth);
88+
for (unsigned int i = 0; i < gTableNum; ++i) {
89+
assert(gTable.size() >= gTableNum);
90+
indexGFoldedHist.push_back(GlobalFoldedHist(gHistLen[i], gTableIdxWidth - numCtrsPerLineBits, 16));
91+
}
92+
gIndex.resize(gTableNum);
93+
94+
auto pTableSize = allocPredTable(pTable, pTableNum, pTableIdxWidth);
95+
for (unsigned int i = 0; i < pTableNum; ++i) {
96+
assert(pTable.size() >= pTableNum);
97+
indexPFoldedHist.push_back(PathFoldedHist(pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2));
98+
}
99+
pIndex.resize(pTableNum);
100+
101+
allocPredTable(biasTable, biasTableNum, biasTableIdxWidth);
102+
biasIndex.resize(biasTableNum);
103+
104+
auto weightTableSize = pow2(weightTableIdxWidth);
105+
bwWeightTable.resize(weightTableSize);
106+
lWeightTable.resize(weightTableSize);
107+
iWeightTable.resize(weightTableSize);
108+
gWeightTable.resize(weightTableSize);
109+
pWeightTable.resize(weightTableSize);
110+
biasWeightTable.resize(weightTableSize);
111+
112+
pUpdateThreshold.resize(pow2(thresholdTablelogSize));
113+
}
114+
115+
#ifdef UNIT_TEST
116+
// Test-only constructor (UNIT_TEST builds): hard-codes a tiny,
// deterministic configuration so BTBMGSC can be instantiated without the
// gem5 Params machinery; storage is built by the same initStorage() path
// as the simulation constructor.
BTBMGSC::BTBMGSC()
    : TimedBaseBTBPredictor(),
      bwTableNum(1),
      bwTableIdxWidth(4),
      bwHistLen({4}),
      numEntriesFirstLocalHistories(4),
      lTableNum(1),
      lTableIdxWidth(4),
      lHistLen({4}),
      iTableNum(1),
      iTableIdxWidth(4),
      // `ImliFoldedHist` requires foldedLen >= histLen. With
      // `numCtrsPerLine=8` and `iTableIdxWidth=4`, foldedLen is small
      // (4 - log2(8) = 1), so keep histLen=1 for unit tests.
      iHistLen({1}),
      gTableNum(1),
      gTableIdxWidth(4),
      gHistLen({4}),
      pTableNum(1),
      pTableIdxWidth(4),
      pHistLen({4}),
      biasTableNum(1),
      biasTableIdxWidth(4),
      scCountersWidth(6),
      thresholdTablelogSize(4),
      updateThresholdWidth(8),
      pUpdateThresholdWidth(8),
      extraWeightsWidth(6),
      weightTableIdxWidth(4),
      // Keep consistent with the `src/cpu/pred/BranchPredictor.py`
      // default (8 counters per SRAM line). This models "read a whole
      // SRAM line, then pick a lane" behavior in `posHash()`.
      numCtrsPerLine(8),
      forceUseSC(false),
      enableBwTable(true),
      enableLTable(true),
      enableITable(true),
      enableGTable(true),
      enablePTable(true),
      enableBiasTable(true),
      enablePCThreshold(false),
      mgscStats()
{
    // Test-only small config: keep tables tiny and deterministic so the
    // unit tests stay fast.
    needMoreHistories = false;

    initStorage();
    updateThreshold = 35 * 8;
}
163+
#else
27164
// Constructor: Initialize MGSC predictor with given parameters
28165
BTBMGSC::BTBMGSC(const Params &p)
29166
: TimedBaseBTBPredictor(p),
@@ -64,82 +201,13 @@ BTBMGSC::BTBMGSC(const Params &p)
64201
{
65202
DPRINTF(MGSC, "BTBMGSC constructor\n");
66203
this->needMoreHistories = p.needMoreHistories;
67-
68-
assert(isPowerOf2(numCtrsPerLine));
69-
numCtrsPerLineBits = log2i(numCtrsPerLine);
70-
71-
bwTable.resize(bwTableNum);
72-
auto bwTableSize = std::pow(2, bwTableIdxWidth);
73-
assert(bwTableSize > numCtrsPerLine);
74-
for (unsigned int i = 0; i < bwTableNum; ++i) {
75-
bwTable[i].resize(bwTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
76-
indexBwFoldedHist.push_back(GlobalBwFoldedHist(bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16));
77-
}
78-
bwIndex.resize(bwTableNum);
79-
80-
lTable.resize(lTableNum);
81-
indexLFoldedHist.resize(numEntriesFirstLocalHistories);
82-
auto lTableSize = std::pow(2, lTableIdxWidth);
83-
assert(lTableSize > numCtrsPerLine);
84-
for (unsigned int i = 0; i < lTableNum; ++i) {
85-
lTable[i].resize(lTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
86-
for (unsigned int k = 0; k < numEntriesFirstLocalHistories; ++k) {
87-
indexLFoldedHist[k].push_back(LocalFoldedHist(lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16));
88-
}
89-
}
90-
lIndex.resize(lTableNum);
91-
92-
iTable.resize(iTableNum);
93-
auto iTableSize = std::pow(2, iTableIdxWidth);
94-
assert(iTableSize > numCtrsPerLine);
95-
for (unsigned int i = 0; i < iTableNum; ++i) {
96-
assert(std::pow(2, iHistLen[i]) <= iTableSize);
97-
iTable[i].resize(iTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
98-
indexIFoldedHist.push_back(ImliFoldedHist(iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16));
99-
}
100-
iIndex.resize(iTableNum);
101-
102-
gTable.resize(gTableNum);
103-
auto gTableSize = std::pow(2, gTableIdxWidth);
104-
assert(gTableSize > numCtrsPerLine);
105-
for (unsigned int i = 0; i < gTableNum; ++i) {
106-
assert(gTable.size() >= gTableNum);
107-
gTable[i].resize(gTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
108-
indexGFoldedHist.push_back(GlobalFoldedHist(gHistLen[i], gTableIdxWidth - numCtrsPerLineBits, 16));
109-
}
110-
gIndex.resize(gTableNum);
111-
112-
pTable.resize(pTableNum);
113-
auto pTableSize = std::pow(2, pTableIdxWidth);
114-
assert(pTableSize > numCtrsPerLine);
115-
for (unsigned int i = 0; i < pTableNum; ++i) {
116-
assert(pTable.size() >= pTableNum);
117-
pTable[i].resize(pTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
118-
indexPFoldedHist.push_back(PathFoldedHist(pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2));
119-
}
120-
pIndex.resize(pTableNum);
121-
122-
biasTable.resize(biasTableNum);
123-
auto biasTableSize = std::pow(2, biasTableIdxWidth);
124-
assert(biasTableSize > numCtrsPerLine);
125-
for (unsigned int i = 0; i < biasTableNum; ++i) {
126-
biasTable[i].resize(biasTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
127-
}
128-
biasIndex.resize(biasTableNum);
129-
130-
bwWeightTable.resize(std::pow(2, weightTableIdxWidth));
131-
lWeightTable.resize(std::pow(2, weightTableIdxWidth));
132-
iWeightTable.resize(std::pow(2, weightTableIdxWidth));
133-
gWeightTable.resize(std::pow(2, weightTableIdxWidth));
134-
pWeightTable.resize(std::pow(2, weightTableIdxWidth));
135-
biasWeightTable.resize(std::pow(2, weightTableIdxWidth));
136-
pUpdateThreshold.resize(std::pow(2, thresholdTablelogSize));
137-
204+
initStorage();
138205
updateThreshold = 35 * 8;
139206

140207
hasDB = true;
141208
dbName = std::string("mgsc");
142209
}
210+
#endif
143211
// Trivial destructor; all storage is owned by standard containers (RAII).
BTBMGSC::~BTBMGSC() = default;
144212

145213
// Set up tracing for debugging
@@ -601,7 +669,7 @@ BTBMGSC::recordPredictionStats(const MgscPrediction &pred, bool actual_taken, bo
601669
}
602670

603671
// Record raw percsum correctness and weight criticality for each table
604-
auto recordPercsum = [&](int percsum, statistics::Scalar &correct, statistics::Scalar &wrong) {
672+
auto recordPercsum = [&](int percsum, auto &correct, auto &wrong) {
605673
if ((percsum >= 0) == actual_taken) {
606674
correct++;
607675
} else {
@@ -1200,10 +1268,11 @@ BTBMGSC::recoverLHist(const std::vector<boost::dynamic_bitset<>> &history, const
12001268
indexLFoldedHist[k][i].recover(predMeta->indexLFoldedHist[k][i]);
12011269
}
12021270
}
1203-
doUpdateHist(history[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))], shamt, cond_taken,
1204-
indexLFoldedHist[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))]);
1205-
}
1271+
doUpdateHist(history[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))], shamt, cond_taken,
1272+
indexLFoldedHist[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))]);
1273+
}
12061274

1275+
#ifndef UNIT_TEST
12071276
// Constructor for TAGE statistics
12081277
BTBMGSC::MgscStats::MgscStats(statistics::Group *parent)
12091278
: statistics::Group(parent),
@@ -1262,7 +1331,9 @@ BTBMGSC::MgscStats::MgscStats(statistics::Group *parent)
12621331
ADD_STAT(scLowBypass, statistics::units::Count::get(), "tage low conf, sc not used")
12631332
{
12641333
}
1334+
#endif
12651335

1336+
#ifndef UNIT_TEST
12661337
void
12671338
BTBMGSC::commitBranch(const FetchStream &stream, const DynInstPtr &inst)
12681339
{
@@ -1313,6 +1384,7 @@ BTBMGSC::commitBranch(const FetchStream &stream, const DynInstPtr &inst)
13131384
}
13141385

13151386
}
1387+
#endif
13161388

13171389
void
13181390
BTBMGSC::checkFoldedHist(const boost::dynamic_bitset<> &Ghistory, const boost::dynamic_bitset<> &PHistory,
@@ -1341,6 +1413,10 @@ BTBMGSC::checkFoldedHist(const boost::dynamic_bitset<> &Ghistory, const boost::d
13411413
}
13421414
}
13431415

1416+
#ifdef UNIT_TEST
1417+
} // namespace test
1418+
#endif
1419+
13441420
} // namespace btb_pred
13451421

13461422
} // namespace branch_prediction

0 commit comments

Comments (0)