Skip to content

Commit 1fe5532

Browse files
committed
Allow MHA plugin to run on SM_86 as well
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
1 parent 853d331 commit 1fe5532

2 files changed

Lines changed: 3 additions & 2 deletions

File tree

plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -62,7 +62,7 @@ QKVToContextInterleavedPlugin::QKVToContextInterleavedPlugin(
 mSM = getSMVersion();
 // variable sequence length is only supported with the fused MHA kernels
 // we should not override mS!
-assert((mSM == kSM_AMPERE || mSM == kSM_TURING || mSM == kSM_XAVIER)
+assert((mSM == kSM_AMPERE_100 || mSM == kSM_AMPERE_10X || mSM == kSM_TURING || mSM == kSM_XAVIER)
 && "requesting maxSeqlen not compatible with GPU arch");
 // the layout changes: SxB will be a combined \sum_i s_i and hdim will be the 2nd dimension instead of the third
 mXmmaKernel = getXMMAKernelsV2(DATA_TYPE_INT8, mSM);

plugin/bertQKVToContextPlugin/qkvToContextInt8InterleavedPlugin.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@ namespace bert
 {
 static constexpr int32_t kSM_XAVIER = 72;
 static constexpr int32_t kSM_TURING = 75;
-static constexpr int32_t kSM_AMPERE = 80;
+static constexpr int32_t kSM_AMPERE_100 = 80;
+static constexpr int32_t kSM_AMPERE_10X = 86;
 
 class QKVToContextInterleavedPlugin : public nvinfer1::IPluginV2DynamicExt
 {

0 commit comments

Comments
 (0)