11#include " cpu/pred/btb/btb_mgsc.hh"
22
3- #ifndef UNIT_TEST
3+ #include " base/intmath.hh"
4+
5+ #ifdef UNIT_TEST
6+ #include " cpu/pred/btb/test/test_dprintf.hh"
7+
8+ // Define debug flags for unit testing
9+ namespace gem5 {
10+ namespace debug {
11+ bool MGSC = true ;
12+ }
13+ }
14+ #else
415#include " cpu/o3/dyn_inst.hh"
16+ #include " debug/MGSC.hh"
517
618#endif
719
1325#include < type_traits>
1426#include < vector>
1527
16- #include " debug/MGSC.hh"
17-
1828namespace gem5
1929{
2030
@@ -24,6 +34,133 @@ namespace branch_prediction
2434namespace btb_pred
2535{
2636
37+ #ifdef UNIT_TEST
38+ namespace test
39+ {
40+ #endif
41+
42+ void
43+ BTBMGSC::initStorage ()
44+ {
45+ auto pow2 = [](unsigned width) -> uint64_t {
46+ assert (width < 63 );
47+ return 1ULL << width;
48+ };
49+ auto allocPredTable = [&](std::vector<std::vector<std::vector<int16_t >>> &table, unsigned numTables,
50+ unsigned idxWidth) -> uint64_t {
51+ table.resize (numTables);
52+ auto tableSize = pow2 (idxWidth);
53+ assert (tableSize > numCtrsPerLine);
54+ for (unsigned int i = 0 ; i < numTables; ++i) {
55+ table[i].resize (tableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
56+ }
57+ return tableSize;
58+ };
59+
60+ assert (isPowerOf2 (numCtrsPerLine));
61+ numCtrsPerLineBits = log2i (numCtrsPerLine);
62+
63+ auto bwTableSize = allocPredTable (bwTable, bwTableNum, bwTableIdxWidth);
64+ for (unsigned int i = 0 ; i < bwTableNum; ++i) {
65+ indexBwFoldedHist.push_back (GlobalBwFoldedHist (bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16 ));
66+ }
67+ bwIndex.resize (bwTableNum);
68+
69+ auto lTableSize = allocPredTable (lTable, lTableNum, lTableIdxWidth);
70+ indexLFoldedHist.resize (numEntriesFirstLocalHistories);
71+ for (unsigned int i = 0 ; i < lTableNum; ++i) {
72+ for (unsigned int k = 0 ; k < numEntriesFirstLocalHistories; ++k) {
73+ indexLFoldedHist[k].push_back (LocalFoldedHist (lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16 ));
74+ }
75+ }
76+ lIndex.resize (lTableNum);
77+
78+ auto iTableSize = allocPredTable (iTable, iTableNum, iTableIdxWidth);
79+ for (unsigned int i = 0 ; i < iTableNum; ++i) {
80+ assert (iHistLen[i] >= 0 );
81+ assert (static_cast <unsigned >(iHistLen[i]) < 63 );
82+ assert (pow2 (static_cast <unsigned >(iHistLen[i])) <= iTableSize);
83+ indexIFoldedHist.push_back (ImliFoldedHist (iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16 ));
84+ }
85+ iIndex.resize (iTableNum);
86+
87+ auto gTableSize = allocPredTable (gTable , gTableNum , gTableIdxWidth );
88+ for (unsigned int i = 0 ; i < gTableNum ; ++i) {
89+ assert (gTable .size () >= gTableNum );
90+ indexGFoldedHist.push_back (GlobalFoldedHist (gHistLen [i], gTableIdxWidth - numCtrsPerLineBits, 16 ));
91+ }
92+ gIndex .resize (gTableNum );
93+
94+ auto pTableSize = allocPredTable (pTable, pTableNum, pTableIdxWidth);
95+ for (unsigned int i = 0 ; i < pTableNum; ++i) {
96+ assert (pTable.size () >= pTableNum);
97+ indexPFoldedHist.push_back (PathFoldedHist (pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2 ));
98+ }
99+ pIndex.resize (pTableNum);
100+
101+ allocPredTable (biasTable, biasTableNum, biasTableIdxWidth);
102+ biasIndex.resize (biasTableNum);
103+
104+ auto weightTableSize = pow2 (weightTableIdxWidth);
105+ bwWeightTable.resize (weightTableSize);
106+ lWeightTable.resize (weightTableSize);
107+ iWeightTable.resize (weightTableSize);
108+ gWeightTable .resize (weightTableSize);
109+ pWeightTable.resize (weightTableSize);
110+ biasWeightTable.resize (weightTableSize);
111+
112+ pUpdateThreshold.resize (pow2 (thresholdTablelogSize));
113+ }
114+
115+ #ifdef UNIT_TEST
116+ BTBMGSC::BTBMGSC ()
117+ : TimedBaseBTBPredictor(),
118+ bwTableNum(1 ),
119+ bwTableIdxWidth(4 ),
120+ bwHistLen({4 }),
121+ numEntriesFirstLocalHistories(4 ),
122+ lTableNum(1 ),
123+ lTableIdxWidth(4 ),
124+ lHistLen({4 }),
125+ iTableNum(1 ),
126+ iTableIdxWidth(4 ),
127+ // `ImliFoldedHist` requires foldedLen >= histLen. With `numCtrsPerLine=8` and `iTableIdxWidth=4`,
128+ // foldedLen is small (4 - log2(8) = 1), so keep histLen=1 for unit tests.
129+ iHistLen({1 }),
130+ gTableNum(1 ),
131+ gTableIdxWidth(4 ),
132+ gHistLen({4 }),
133+ pTableNum(1 ),
134+ pTableIdxWidth(4 ),
135+ pHistLen({4 }),
136+ biasTableNum(1 ),
137+ biasTableIdxWidth(4 ),
138+ scCountersWidth(6 ),
139+ thresholdTablelogSize(4 ),
140+ updateThresholdWidth(8 ),
141+ pUpdateThresholdWidth(8 ),
142+ extraWeightsWidth(6 ),
143+ weightTableIdxWidth(4 ),
144+ // Keep consistent with `src/cpu/pred/BranchPredictor.py` default (8 counters per SRAM line).
145+ // This models "read a whole SRAM line, then pick a lane" behavior in `posHash()`.
146+ numCtrsPerLine(8 ),
147+ forceUseSC(false ),
148+ enableBwTable(true ),
149+ enableLTable(true ),
150+ enableITable(true ),
151+ enableGTable(true ),
152+ enablePTable(true ),
153+ enableBiasTable(true ),
154+ enablePCThreshold(false ),
155+ mgscStats()
156+ {
157+ // Test-only small config: keep tables tiny and deterministic for fast unit tests.
158+ needMoreHistories = false ;
159+
160+ initStorage ();
161+ updateThreshold = 35 * 8 ;
162+ }
163+ #else
27164// Constructor: Initialize MGSC predictor with given parameters
28165BTBMGSC::BTBMGSC (const Params &p)
29166 : TimedBaseBTBPredictor(p),
@@ -64,82 +201,13 @@ BTBMGSC::BTBMGSC(const Params &p)
64201{
65202 DPRINTF (MGSC, " BTBMGSC constructor\n " );
66203 this ->needMoreHistories = p.needMoreHistories ;
67-
68- assert (isPowerOf2 (numCtrsPerLine));
69- numCtrsPerLineBits = log2i (numCtrsPerLine);
70-
71- bwTable.resize (bwTableNum);
72- auto bwTableSize = std::pow (2 , bwTableIdxWidth);
73- assert (bwTableSize > numCtrsPerLine);
74- for (unsigned int i = 0 ; i < bwTableNum; ++i) {
75- bwTable[i].resize (bwTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
76- indexBwFoldedHist.push_back (GlobalBwFoldedHist (bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16 ));
77- }
78- bwIndex.resize (bwTableNum);
79-
80- lTable.resize (lTableNum);
81- indexLFoldedHist.resize (numEntriesFirstLocalHistories);
82- auto lTableSize = std::pow (2 , lTableIdxWidth);
83- assert (lTableSize > numCtrsPerLine);
84- for (unsigned int i = 0 ; i < lTableNum; ++i) {
85- lTable[i].resize (lTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
86- for (unsigned int k = 0 ; k < numEntriesFirstLocalHistories; ++k) {
87- indexLFoldedHist[k].push_back (LocalFoldedHist (lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16 ));
88- }
89- }
90- lIndex.resize (lTableNum);
91-
92- iTable.resize (iTableNum);
93- auto iTableSize = std::pow (2 , iTableIdxWidth);
94- assert (iTableSize > numCtrsPerLine);
95- for (unsigned int i = 0 ; i < iTableNum; ++i) {
96- assert (std::pow (2 , iHistLen[i]) <= iTableSize);
97- iTable[i].resize (iTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
98- indexIFoldedHist.push_back (ImliFoldedHist (iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16 ));
99- }
100- iIndex.resize (iTableNum);
101-
102- gTable .resize (gTableNum );
103- auto gTableSize = std::pow (2 , gTableIdxWidth );
104- assert (gTableSize > numCtrsPerLine);
105- for (unsigned int i = 0 ; i < gTableNum ; ++i) {
106- assert (gTable .size () >= gTableNum );
107- gTable [i].resize (gTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
108- indexGFoldedHist.push_back (GlobalFoldedHist (gHistLen [i], gTableIdxWidth - numCtrsPerLineBits, 16 ));
109- }
110- gIndex .resize (gTableNum );
111-
112- pTable.resize (pTableNum);
113- auto pTableSize = std::pow (2 , pTableIdxWidth);
114- assert (pTableSize > numCtrsPerLine);
115- for (unsigned int i = 0 ; i < pTableNum; ++i) {
116- assert (pTable.size () >= pTableNum);
117- pTable[i].resize (pTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
118- indexPFoldedHist.push_back (PathFoldedHist (pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2 ));
119- }
120- pIndex.resize (pTableNum);
121-
122- biasTable.resize (biasTableNum);
123- auto biasTableSize = std::pow (2 , biasTableIdxWidth);
124- assert (biasTableSize > numCtrsPerLine);
125- for (unsigned int i = 0 ; i < biasTableNum; ++i) {
126- biasTable[i].resize (biasTableSize / numCtrsPerLine, std::vector<int16_t >(numCtrsPerLine, 0 ));
127- }
128- biasIndex.resize (biasTableNum);
129-
130- bwWeightTable.resize (std::pow (2 , weightTableIdxWidth));
131- lWeightTable.resize (std::pow (2 , weightTableIdxWidth));
132- iWeightTable.resize (std::pow (2 , weightTableIdxWidth));
133- gWeightTable .resize (std::pow (2 , weightTableIdxWidth));
134- pWeightTable.resize (std::pow (2 , weightTableIdxWidth));
135- biasWeightTable.resize (std::pow (2 , weightTableIdxWidth));
136- pUpdateThreshold.resize (std::pow (2 , thresholdTablelogSize));
137-
204+ initStorage ();
138205 updateThreshold = 35 * 8 ;
139206
140207 hasDB = true ;
141208 dbName = std::string (" mgsc" );
142209}
210+ #endif
143211BTBMGSC::~BTBMGSC () {}
144212
145213// Set up tracing for debugging
@@ -601,7 +669,7 @@ BTBMGSC::recordPredictionStats(const MgscPrediction &pred, bool actual_taken, bo
601669 }
602670
603671 // Record raw percsum correctness and weight criticality for each table
604- auto recordPercsum = [&](int percsum, statistics::Scalar &correct, statistics::Scalar &wrong) {
672+ auto recordPercsum = [&](int percsum, auto &correct, auto &wrong) {
605673 if ((percsum >= 0 ) == actual_taken) {
606674 correct++;
607675 } else {
@@ -1200,10 +1268,11 @@ BTBMGSC::recoverLHist(const std::vector<boost::dynamic_bitset<>> &history, const
12001268 indexLFoldedHist[k][i].recover (predMeta->indexLFoldedHist [k][i]);
12011269 }
12021270 }
1203- doUpdateHist (history[getPcIndex (entry.startPC , log2 (numEntriesFirstLocalHistories))], shamt, cond_taken,
1204- indexLFoldedHist[getPcIndex (entry.startPC , log2 (numEntriesFirstLocalHistories))]);
1205- }
1271+ doUpdateHist (history[getPcIndex (entry.startPC , log2 (numEntriesFirstLocalHistories))], shamt, cond_taken,
1272+ indexLFoldedHist[getPcIndex (entry.startPC , log2 (numEntriesFirstLocalHistories))]);
1273+ }
12061274
1275+ #ifndef UNIT_TEST
12071276// Constructor for TAGE statistics
12081277BTBMGSC::MgscStats::MgscStats (statistics::Group *parent)
12091278 : statistics::Group(parent),
@@ -1262,7 +1331,9 @@ BTBMGSC::MgscStats::MgscStats(statistics::Group *parent)
12621331 ADD_STAT(scLowBypass, statistics::units::Count::get(), "tage low conf, sc not used")
12631332{
12641333}
1334+ #endif
12651335
1336+ #ifndef UNIT_TEST
12661337void
12671338BTBMGSC::commitBranch (const FetchStream &stream, const DynInstPtr &inst)
12681339{
@@ -1313,6 +1384,7 @@ BTBMGSC::commitBranch(const FetchStream &stream, const DynInstPtr &inst)
13131384 }
13141385
13151386}
1387+ #endif
13161388
13171389void
13181390BTBMGSC::checkFoldedHist (const boost::dynamic_bitset<> &Ghistory, const boost::dynamic_bitset<> &PHistory,
@@ -1341,6 +1413,10 @@ BTBMGSC::checkFoldedHist(const boost::dynamic_bitset<> &Ghistory, const boost::d
13411413 }
13421414}
13431415
1416+ #ifdef UNIT_TEST
1417+ } // namespace test
1418+ #endif
1419+
13441420} // namespace btb_pred
13451421
13461422} // namespace branch_prediction
0 commit comments