Skip to content
Merged

Sc ut #710

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 179 additions & 85 deletions src/cpu/pred/btb/btb_mgsc.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
#include "cpu/pred/btb/btb_mgsc.hh"

#ifndef UNIT_TEST
#include "base/intmath.hh"

#ifdef UNIT_TEST
#include "cpu/pred/btb/test/test_dprintf.hh"

// Define debug flags for unit testing
namespace gem5 {
namespace debug {
bool MGSC = true;
}
}
#else
#include "cpu/o3/dyn_inst.hh"
#include "debug/MGSC.hh"

#endif

Expand All @@ -13,8 +25,6 @@
#include <type_traits>
#include <vector>

#include "debug/MGSC.hh"

namespace gem5
{

Expand All @@ -24,6 +34,140 @@ namespace branch_prediction
namespace btb_pred
{

#ifdef UNIT_TEST
namespace test
{
#endif

// Allocate all MGSC prediction and weight storage.
//
// Each prediction table is modeled as SRAM lines of `numCtrsPerLine` 16-bit
// counters (a whole line is read, then one lane is selected), so a table with
// a `w`-bit index holds 2^w / numCtrsPerLine lines. Folded histories are sized
// to the per-table index width minus the in-line counter-select bits.
void
BTBMGSC::initStorage()
{
    // 2^width with an overflow guard (result must fit in uint64_t).
    auto pow2 = [](unsigned width) -> uint64_t {
        assert(width < 63);
        return 1ULL << width;
    };
    // Size `table` as numTables tables of (2^idxWidth / numCtrsPerLine) lines,
    // each line holding numCtrsPerLine zero-initialized counters.
    // Returns the per-table counter count (2^idxWidth).
    auto allocPredTable = [&](std::vector<std::vector<std::vector<int16_t>>> &table, unsigned numTables,
                              unsigned idxWidth) -> uint64_t {
        table.resize(numTables);
        auto tableSize = pow2(idxWidth);
        assert(tableSize > numCtrsPerLine);
        for (unsigned int i = 0; i < numTables; ++i) {
            table[i].resize(tableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
        }
        return tableSize;
    };

    // The counter-select field must be an exact power-of-two lane count.
    assert(isPowerOf2(numCtrsPerLine));
    numCtrsPerLineBits = log2i(numCtrsPerLine);

    // Backward-branch (BW) global-history tables and their folded histories.
    allocPredTable(bwTable, bwTableNum, bwTableIdxWidth);
    for (unsigned int i = 0; i < bwTableNum; ++i) {
        indexBwFoldedHist.push_back(GlobalBwFoldedHist(bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16));
    }
    bwIndex.resize(bwTableNum);

    // Local-history tables: one folded-history set per first-level local
    // history entry, each set holding one folded history per table.
    allocPredTable(lTable, lTableNum, lTableIdxWidth);
    indexLFoldedHist.resize(numEntriesFirstLocalHistories);
    for (unsigned int i = 0; i < lTableNum; ++i) {
        for (unsigned int k = 0; k < numEntriesFirstLocalHistories; ++k) {
            indexLFoldedHist[k].push_back(LocalFoldedHist(lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16));
        }
    }
    lIndex.resize(lTableNum);

    // IMLI tables: every reachable counter value (2^iHistLen) must be
    // addressable within the table, hence the range checks below.
    auto iTableSize = allocPredTable(iTable, iTableNum, iTableIdxWidth);
    for (unsigned int i = 0; i < iTableNum; ++i) {
        assert(iHistLen[i] >= 0);
        assert(static_cast<unsigned>(iHistLen[i]) < 63);
        assert(pow2(static_cast<unsigned>(iHistLen[i])) <= iTableSize);
        indexIFoldedHist.push_back(ImliFoldedHist(iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16));
    }
    iIndex.resize(iTableNum);

    // Global-history tables.
    allocPredTable(gTable, gTableNum, gTableIdxWidth);
    for (unsigned int i = 0; i < gTableNum; ++i) {
        indexGFoldedHist.push_back(GlobalFoldedHist(gHistLen[i], gTableIdxWidth - numCtrsPerLineBits, 16));
    }
    gIndex.resize(gTableNum);

    // Path-history tables (path folded histories use 2-bit chunks).
    allocPredTable(pTable, pTableNum, pTableIdxWidth);
    for (unsigned int i = 0; i < pTableNum; ++i) {
        indexPFoldedHist.push_back(PathFoldedHist(pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2));
    }
    pIndex.resize(pTableNum);

    // Bias tables are indexed by PC only, so no folded history is needed.
    allocPredTable(biasTable, biasTableNum, biasTableIdxWidth);
    biasIndex.resize(biasTableNum);

    // All per-component weight tables share one index width.
    auto weightTableSize = pow2(weightTableIdxWidth);
    bwWeightTable.resize(weightTableSize);
    lWeightTable.resize(weightTableSize);
    iWeightTable.resize(weightTableSize);
    gWeightTable.resize(weightTableSize);
    pWeightTable.resize(weightTableSize);
    biasWeightTable.resize(weightTableSize);

    // Per-PC update thresholds (used when enablePCThreshold is set).
    pUpdateThreshold.resize(pow2(thresholdTablelogSize));
}

#ifdef UNIT_TEST
// Unit-test-only constructor: builds a small, deterministic MGSC
// configuration without a gem5 Params object. NOTE(review): initializer
// order here must match member declaration order in btb_mgsc.hh — verify
// when adding members.
BTBMGSC::BTBMGSC()
    : TimedBaseBTBPredictor(),
      // One table per component keeps the test configuration minimal.
      bwTableNum(1),
      // Use a slightly larger idx width so foldedLen is not too small (helps pattern-learning tests).
      bwTableIdxWidth(6),
      bwHistLen({4}),
      numEntriesFirstLocalHistories(4),
      lTableNum(1),
      // Use a slightly larger idx width so foldedLen is not too small (helps pattern-learning tests).
      lTableIdxWidth(6),
      lHistLen({4}),
      iTableNum(1),
      iTableIdxWidth(5),
      // `ImliFoldedHist` requires foldedLen >= histLen. With `numCtrsPerLine=8` and `iTableIdxWidth=5`,
      // foldedLen is small (5 - log2(8) = 2), so keep histLen=1 for unit tests.
      // Also keep it >= 2 so we can build loop-trip-count tests on IMLI.
      iHistLen({2}),
      gTableNum(1),
      // Use a slightly larger idx width so foldedLen is not too small (helps pattern-learning tests).
      gTableIdxWidth(6),
      gHistLen({4}),
      pTableNum(1),
      // Use a slightly larger idx width so foldedLen is not too small (helps pattern-learning tests).
      pTableIdxWidth(6),
      pHistLen({4}),
      biasTableNum(1),
      biasTableIdxWidth(5),
      scCountersWidth(6),
      thresholdTablelogSize(4),
      updateThresholdWidth(12),
      pUpdateThresholdWidth(8),
      extraWeightsWidth(6),
      weightTableIdxWidth(4),
      // Keep consistent with `src/cpu/pred/BranchPredictor.py` default (8 counters per SRAM line).
      // This models "read a whole SRAM line, then pick a lane" behavior in `posHash()`.
      numCtrsPerLine(8),
      forceUseSC(false),
      // Enable every component table so each one can be exercised by tests.
      enableBwTable(true),
      enableLTable(true),
      enableITable(true),
      enableGTable(true),
      enablePTable(true),
      enableBiasTable(true),
      enablePCThreshold(false),
      mgscStats()
{
    // Test-only small config: keep tables tiny and deterministic for fast unit tests.
    // MGSC uses multiple histories (GHR/PHR/BWHR/LHR). Keep it enabled in unit tests so we can
    // build training-loop style tests that exercise each table.
    needMoreHistories = true;

    initStorage();
    // Same initial confidence threshold as the Params-based constructor.
    updateThreshold = 35 * 8;
}
#else
// Constructor: Initialize MGSC predictor with given parameters
BTBMGSC::BTBMGSC(const Params &p)
: TimedBaseBTBPredictor(p),
Expand Down Expand Up @@ -64,82 +208,13 @@ BTBMGSC::BTBMGSC(const Params &p)
{
DPRINTF(MGSC, "BTBMGSC constructor\n");
this->needMoreHistories = p.needMoreHistories;

assert(isPowerOf2(numCtrsPerLine));
numCtrsPerLineBits = log2i(numCtrsPerLine);

bwTable.resize(bwTableNum);
auto bwTableSize = std::pow(2, bwTableIdxWidth);
assert(bwTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < bwTableNum; ++i) {
bwTable[i].resize(bwTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
indexBwFoldedHist.push_back(GlobalBwFoldedHist(bwHistLen[i], bwTableIdxWidth - numCtrsPerLineBits, 16));
}
bwIndex.resize(bwTableNum);

lTable.resize(lTableNum);
indexLFoldedHist.resize(numEntriesFirstLocalHistories);
auto lTableSize = std::pow(2, lTableIdxWidth);
assert(lTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < lTableNum; ++i) {
lTable[i].resize(lTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
for (unsigned int k = 0; k < numEntriesFirstLocalHistories; ++k) {
indexLFoldedHist[k].push_back(LocalFoldedHist(lHistLen[i], lTableIdxWidth - numCtrsPerLineBits, 16));
}
}
lIndex.resize(lTableNum);

iTable.resize(iTableNum);
auto iTableSize = std::pow(2, iTableIdxWidth);
assert(iTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < iTableNum; ++i) {
assert(std::pow(2, iHistLen[i]) <= iTableSize);
iTable[i].resize(iTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
indexIFoldedHist.push_back(ImliFoldedHist(iHistLen[i], iTableIdxWidth - numCtrsPerLineBits, 16));
}
iIndex.resize(iTableNum);

gTable.resize(gTableNum);
auto gTableSize = std::pow(2, gTableIdxWidth);
assert(gTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < gTableNum; ++i) {
assert(gTable.size() >= gTableNum);
gTable[i].resize(gTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
indexGFoldedHist.push_back(GlobalFoldedHist(gHistLen[i], gTableIdxWidth - numCtrsPerLineBits, 16));
}
gIndex.resize(gTableNum);

pTable.resize(pTableNum);
auto pTableSize = std::pow(2, pTableIdxWidth);
assert(pTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < pTableNum; ++i) {
assert(pTable.size() >= pTableNum);
pTable[i].resize(pTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
indexPFoldedHist.push_back(PathFoldedHist(pHistLen[i], pTableIdxWidth - numCtrsPerLineBits, 2));
}
pIndex.resize(pTableNum);

biasTable.resize(biasTableNum);
auto biasTableSize = std::pow(2, biasTableIdxWidth);
assert(biasTableSize > numCtrsPerLine);
for (unsigned int i = 0; i < biasTableNum; ++i) {
biasTable[i].resize(biasTableSize / numCtrsPerLine, std::vector<int16_t>(numCtrsPerLine, 0));
}
biasIndex.resize(biasTableNum);

bwWeightTable.resize(std::pow(2, weightTableIdxWidth));
lWeightTable.resize(std::pow(2, weightTableIdxWidth));
iWeightTable.resize(std::pow(2, weightTableIdxWidth));
gWeightTable.resize(std::pow(2, weightTableIdxWidth));
pWeightTable.resize(std::pow(2, weightTableIdxWidth));
biasWeightTable.resize(std::pow(2, weightTableIdxWidth));
pUpdateThreshold.resize(std::pow(2, thresholdTablelogSize));

initStorage();
updateThreshold = 35 * 8;

hasDB = true;
dbName = std::string("mgsc");
}
#endif
BTBMGSC::~BTBMGSC() {}

// Set up tracing for debugging
Expand Down Expand Up @@ -284,6 +359,9 @@ BTBMGSC::generateSinglePrediction(const BTBEntry &btb_entry, const Addr &startPC
lIndex[i] = getHistIndex(startPC, lTableIdxWidth - numCtrsPerLineBits,
indexLFoldedHist[getPcIndex(startPC, log2(numEntriesFirstLocalHistories))][i].get());
}
// std::string buf;
// boost::to_string(indexLFoldedHist[getPcIndex(startPC, log2(numEntriesFirstLocalHistories))][0].getAsBitset(), buf);
// DPRINTF(MGSC, "startPC: %#lx, local index: %d, local_folded_hist: %s\n", startPC, lIndex[0], buf.c_str());

for (unsigned int i = 0; i < iTableNum; ++i) {
iIndex[i] = getHistIndex(startPC, iTableIdxWidth - numCtrsPerLineBits, indexIFoldedHist[i].get());
Expand Down Expand Up @@ -355,6 +433,10 @@ BTBMGSC::generateSinglePrediction(const BTBEntry &btb_entry, const Addr &startPC
// Final prediction, total_sum >= 0 means taken if use_sc_pred
bool taken = use_sc_pred ? (total_sum >= 0) : tage_info.tage_pred_taken;

// DPRINTF(MGSC, "global tag_index: %d, global_percsum: %d, total_sum: %d\n", gIndex[0], g_percsum, total_sum);
// DPRINTF(MGSC, "local tag_index: %d, local_percsum: %d, total_sum: %d\n", lIndex[0], l_percsum, total_sum);
// DPRINTF(MGSC, "path tag_index: %d, path_percsum: %d, total_sum: %d\n", pIndex[0], p_percsum, total_sum);

// Calculate weight scale differences
bool bw_weight_scale_diff = calculateWeightScaleDiff(total_sum, bw_scaled_percsum, bw_percsum);
bool l_weight_scale_diff = calculateWeightScaleDiff(total_sum, l_scaled_percsum, l_percsum);
Expand Down Expand Up @@ -601,7 +683,7 @@ BTBMGSC::recordPredictionStats(const MgscPrediction &pred, bool actual_taken, bo
}

// Record raw percsum correctness and weight criticality for each table
auto recordPercsum = [&](int percsum, statistics::Scalar &correct, statistics::Scalar &wrong) {
auto recordPercsum = [&](int percsum, auto &correct, auto &wrong) {
if ((percsum >= 0) == actual_taken) {
correct++;
} else {
Expand Down Expand Up @@ -1027,21 +1109,23 @@ BTBMGSC::specUpdateBwHist(const boost::dynamic_bitset<> &history, FullBTBPredict
* @brief Updates IMLI branch history for speculative execution
*
* This function updates the branch history for speculative execution
 * based on the prediction information.
*
* It first retrieves the history information from the prediction metadata
* and then calls the doUpdateHist function to update the folded histories.
* Note: IMLI only uses counter, not history bits.
*
 * @param pred The prediction metadata containing history information
*/
void
BTBMGSC::specUpdateIHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred)
BTBMGSC::specUpdateIHist(FullBTBPrediction &pred)
{
int shamt;
bool cond_taken;
std::tie(shamt, cond_taken) = pred.getBwHistInfo();
doUpdateHist(history, shamt, cond_taken, indexIFoldedHist);
// IMLI uses counter only, pass empty bitset (not used by ImliFoldedHist::update)
boost::dynamic_bitset<> dummy;
doUpdateHist(dummy, shamt, cond_taken, indexIFoldedHist);
}

/**
Expand Down Expand Up @@ -1151,14 +1235,14 @@ BTBMGSC::recoverBwHist(const boost::dynamic_bitset<> &history, const FetchStream
* 1. Restores the folded histories from the saved metadata
* 2. Updates the histories with the correct branch outcome
* 3. Ensures predictor state is consistent after recovery
* Note: IMLI only uses counter, not history bits.
*
 * @param entry The fetch stream entry containing recovery information
 * @param shamt Number of bits to shift in history update
 * @param cond_taken The actual branch outcome
*/
void
BTBMGSC::recoverIHist(const boost::dynamic_bitset<> &history, const FetchStream &entry, int shamt, bool cond_taken)
BTBMGSC::recoverIHist(const FetchStream &entry, int shamt, bool cond_taken)
{
if (!isEnabled()) {
return; // No recover when disabled
Expand All @@ -1167,7 +1251,9 @@ BTBMGSC::recoverIHist(const boost::dynamic_bitset<> &history, const FetchStream
for (int i = 0; i < iTableNum; i++) {
indexIFoldedHist[i].recover(predMeta->indexIFoldedHist[i]);
}
doUpdateHist(history, shamt, cond_taken, indexIFoldedHist);
// IMLI uses counter only, pass empty bitset (not used by ImliFoldedHist::update)
boost::dynamic_bitset<> dummy;
doUpdateHist(dummy, shamt, cond_taken, indexIFoldedHist);
}

/**
Expand Down Expand Up @@ -1196,10 +1282,11 @@ BTBMGSC::recoverLHist(const std::vector<boost::dynamic_bitset<>> &history, const
indexLFoldedHist[k][i].recover(predMeta->indexLFoldedHist[k][i]);
}
}
doUpdateHist(history[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))], shamt, cond_taken,
indexLFoldedHist[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))]);
}
doUpdateHist(history[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))], shamt, cond_taken,
indexLFoldedHist[getPcIndex(entry.startPC, log2(numEntriesFirstLocalHistories))]);
}

#ifndef UNIT_TEST
// Constructor for TAGE statistics
BTBMGSC::MgscStats::MgscStats(statistics::Group *parent)
: statistics::Group(parent),
Expand Down Expand Up @@ -1258,7 +1345,9 @@ BTBMGSC::MgscStats::MgscStats(statistics::Group *parent)
ADD_STAT(scLowBypass, statistics::units::Count::get(), "tage low conf, sc not used")
{
}
#endif

#ifndef UNIT_TEST
void
BTBMGSC::commitBranch(const FetchStream &stream, const DynInstPtr &inst)
{
Expand Down Expand Up @@ -1309,6 +1398,7 @@ BTBMGSC::commitBranch(const FetchStream &stream, const DynInstPtr &inst)
}

}
#endif

void
BTBMGSC::checkFoldedHist(const boost::dynamic_bitset<> &Ghistory, const boost::dynamic_bitset<> &PHistory,
Expand Down Expand Up @@ -1337,6 +1427,10 @@ BTBMGSC::checkFoldedHist(const boost::dynamic_bitset<> &Ghistory, const boost::d
}
}

#ifdef UNIT_TEST
} // namespace test
#endif

} // namespace btb_pred

} // namespace branch_prediction
Expand Down
Loading