forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path multiscaleDeformableAttnPlugin.cpp
More file actions
293 lines (250 loc) · 9.14 KB
/
multiscaleDeformableAttnPlugin.cpp
File metadata and controls
293 lines (250 loc) · 9.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "multiscaleDeformableAttnPlugin.h"
#include "multiscaleDeformableAttn.h"
using namespace nvinfer1;
using namespace plugin;
namespace nvinfer1
{
namespace plugin
{
namespace
{
// Plugin name/version reported both by the plugin (getPluginType/getPluginVersion) and by its
// creator (getPluginName/getPluginVersion). Both the pointers and the pointed-to strings are
// const so they cannot be reassigned; `static` is unnecessary inside an unnamed namespace.
constexpr char const* DMHA_VERSION{"1"};
constexpr char const* DMHA_NAME{"MultiscaleDeformableAttnPlugin_TRT"};
} // namespace
// Default constructor: the plugin has no construction-time configuration.
MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin() = default;
// Deserialization constructor. serialize() writes nothing (getSerializationSize() is 0),
// so there is no payload to read back and both arguments are intentionally ignored.
MultiscaleDeformableAttnPlugin::MultiscaleDeformableAttnPlugin(void const* /*data*/, size_t /*length*/)
{
    // Intentionally empty.
}
//! Creates a new plugin instance carrying this instance's namespace.
//! Like deserialization, cloning restores only the namespace; no other state is copied.
//! Returns nullptr (after reporting through caughtError) if construction throws.
nvinfer1::IPluginV2DynamicExt* MultiscaleDeformableAttnPlugin::clone() const PLUGIN_NOEXCEPT
{
    try
    {
        auto* copy = new MultiscaleDeformableAttnPlugin();
        copy->setPluginNamespace(getPluginNamespace());
        return copy;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
//! Shape of the single 4-D output: d[0], d[2], d[3] are taken from input 0 (value) and
//! d[1] from input 3 (sampling locations). outputIndex, nbInputs and exprBuilder are unused.
nvinfer1::DimsExprs MultiscaleDeformableAttnPlugin::getOutputDimensions(int32_t outputIndex,
    nvinfer1::DimsExprs const* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) PLUGIN_NOEXCEPT
{
    nvinfer1::DimsExprs const& value = inputs[0];
    nvinfer1::DimsExprs const& samplingLoc = inputs[3];

    nvinfer1::DimsExprs out;
    out.nbDims = 4;
    out.d[0] = value.d[0];
    out.d[1] = samplingLoc.d[1];
    out.d[2] = value.d[2];
    out.d[3] = value.d[3];
    return out;
}
//! Accepted I/O formats: every tensor must be linear.
//! - Positions 1 and 2 (spatial shapes, level start index) must be INT32.
//! - All other positions must be FP32 or FP16 and match the type of position 0.
bool MultiscaleDeformableAttnPlugin::supportsFormatCombination(
    int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, int32_t nbOutputs) PLUGIN_NOEXCEPT
{
    PLUGIN_ASSERT(nbInputs == 5);
    PLUGIN_ASSERT(nbOutputs == 1);

    nvinfer1::PluginTensorDesc const& desc = inOut[pos];
    if (desc.format != nvinfer1::TensorFormat::kLINEAR)
    {
        return false;
    }

    bool const isIndexTensor = (pos == 1) || (pos == 2);
    if (isIndexTensor)
    {
        return desc.type == nvinfer1::DataType::kINT32;
    }

    bool const isFloatingPoint
        = (desc.type == nvinfer1::DataType::kFLOAT) || (desc.type == nvinfer1::DataType::kHALF);
    return isFloatingPoint && (desc.type == inOut[0].type);
}
//! Build-time validation of the input shapes.
//!
//! Expected ranks are 4, 2, 1, 6, 5 for inputs 0..4. Matching extents are cross-checked:
//! - batch (d[0] of inputs 0/3/4);
//! - heads M (d[2] of inputs 0/3/4);
//! - levels L (d[0] of inputs 1/2, d[3] of inputs 3/4);
//! - points P (d[4] of inputs 3/4);
//! - queries Lq (d[1] of inputs 3/4).
//!
//! TensorRT reports -1 in desc.dims for any dimension that is only known at runtime, so a
//! -1 extent is treated as compatible with anything rather than as a mismatch.
void MultiscaleDeformableAttnPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* inputs, int32_t nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* outputs, int32_t nbOutputs) PLUGIN_NOEXCEPT
{
    // Every check below indexes inputs[0..4]; same I/O contract as supportsFormatCombination().
    PLUGIN_ASSERT(nbInputs == 5);
    PLUGIN_ASSERT(nbOutputs == 1);

    // Two extents are consistent if they are equal or either is still dynamic (-1).
    auto const compatible = [](int64_t a, int64_t b) { return a == -1 || b == -1 || a == b; };

    nvinfer1::Dims const& value = inputs[0].desc.dims;
    nvinfer1::Dims const& spatialShapes = inputs[1].desc.dims;
    nvinfer1::Dims const& levelStartIndex = inputs[2].desc.dims;
    nvinfer1::Dims const& samplingLoc = inputs[3].desc.dims;
    nvinfer1::Dims const& attnWeight = inputs[4].desc.dims;

    // Check for valid input ranks
    PLUGIN_ASSERT(value.nbDims == 4);
    PLUGIN_ASSERT(spatialShapes.nbDims == 2);
    PLUGIN_ASSERT(levelStartIndex.nbDims == 1);
    PLUGIN_ASSERT(samplingLoc.nbDims == 6);
    PLUGIN_ASSERT(attnWeight.nbDims == 5);

    // Check batch consistency: enqueue() passes input 0's batch extent for every tensor.
    PLUGIN_ASSERT(compatible(value.d[0], samplingLoc.d[0]));
    PLUGIN_ASSERT(compatible(value.d[0], attnWeight.d[0]));
    // Check M dimensions consistency
    PLUGIN_ASSERT(compatible(value.d[2], samplingLoc.d[2]));
    PLUGIN_ASSERT(compatible(value.d[2], attnWeight.d[2]));
    // Check L dimensions consistency
    PLUGIN_ASSERT(compatible(spatialShapes.d[0], levelStartIndex.d[0]));
    PLUGIN_ASSERT(compatible(spatialShapes.d[0], samplingLoc.d[3]));
    PLUGIN_ASSERT(compatible(spatialShapes.d[0], attnWeight.d[3]));
    // Check P dimensions consistency
    PLUGIN_ASSERT(compatible(samplingLoc.d[4], attnWeight.d[4]));
    // Check Lq dimensions consistency
    PLUGIN_ASSERT(compatible(samplingLoc.d[1], attnWeight.d[1]));
}
size_t MultiscaleDeformableAttnPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int32_t nbOutputs) const PLUGIN_NOEXCEPT
{
return 0;
}
int32_t MultiscaleDeformableAttnPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workSpace,
cudaStream_t stream) PLUGIN_NOEXCEPT
{
int32_t const batch = inputDesc[0].dims.d[0];
int32_t spatial_size = inputDesc[0].dims.d[1];
int32_t num_heads = inputDesc[0].dims.d[2];
int32_t channels = inputDesc[0].dims.d[3];
int32_t num_levels = inputDesc[1].dims.d[0];
int32_t num_query = inputDesc[3].dims.d[1];
int32_t num_point = inputDesc[3].dims.d[4];
int32_t rc = 0;
if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
{
float const* value = static_cast<float const*>(inputs[0]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
float const* samplingLoc = static_cast<float const*>(inputs[3]);
float const* attnWeight = static_cast<float const*>(inputs[4]);
float* output = static_cast<float*>(outputs[0]);
rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
}
else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
{
const __half* value = static_cast<const __half*>(inputs[0]);
int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
const __half* attnWeight = static_cast<const __half*>(inputs[4]);
__half* output = static_cast<__half*>(outputs[0]);
rc = ms_deform_attn_cuda_forward(stream, value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output,
batch, spatial_size, num_heads, channels, num_levels, num_query, num_point);
}
return rc;
}
void MultiscaleDeformableAttnPlugin::attachToContext(
cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) PLUGIN_NOEXCEPT
{
}
//! No-op counterpart of attachToContext(): nothing was attached.
void MultiscaleDeformableAttnPlugin::detachFromContext() PLUGIN_NOEXCEPT
{
}
// IPluginV2Ext Methods
//! The output uses the data type of input 0 (value); index and nbInputs are unused.
nvinfer1::DataType MultiscaleDeformableAttnPlugin::getOutputDataType(
    int32_t index, nvinfer1::DataType const* inputTypes, int32_t nbInputs) const PLUGIN_NOEXCEPT
{
    nvinfer1::DataType const valueType = inputTypes[0];
    return valueType;
}
// IPluginV2 Methods
//! Plugin type string; identical to the creator's getPluginName().
char const* MultiscaleDeformableAttnPlugin::getPluginType() const PLUGIN_NOEXCEPT
{
    char const* const type = DMHA_NAME;
    return type;
}
//! Plugin version string; identical to the creator's getPluginVersion().
char const* MultiscaleDeformableAttnPlugin::getPluginVersion() const PLUGIN_NOEXCEPT
{
    char const* const version = DMHA_VERSION;
    return version;
}
//! The plugin produces exactly one output tensor.
int32_t MultiscaleDeformableAttnPlugin::getNbOutputs() const PLUGIN_NOEXCEPT
{
    constexpr int32_t kNUM_OUTPUTS{1};
    return kNUM_OUTPUTS;
}
//! Nothing to acquire up front; always reports success (0).
int32_t MultiscaleDeformableAttnPlugin::initialize() PLUGIN_NOEXCEPT
{
    constexpr int32_t kSUCCESS{0};
    return kSUCCESS;
}
//! Counterpart of initialize(); there are no resources to release.
void MultiscaleDeformableAttnPlugin::terminate() PLUGIN_NOEXCEPT
{
}
//! serialize() writes nothing, so the serialized blob is empty.
size_t MultiscaleDeformableAttnPlugin::getSerializationSize() const PLUGIN_NOEXCEPT
{
    constexpr size_t kSERIALIZED_BYTES{0};
    return kSERIALIZED_BYTES;
}
//! No state to persist (getSerializationSize() is 0); the buffer is left untouched.
void MultiscaleDeformableAttnPlugin::serialize(void* /*buffer*/) const PLUGIN_NOEXCEPT
{
    // Intentionally empty.
}
//! Releases this heap-allocated instance (created by clone/createPlugin/deserializePlugin via `new`).
//! The object must not be used after this call.
void MultiscaleDeformableAttnPlugin::destroy() PLUGIN_NOEXCEPT
{
    delete this;
}
//! Stores the namespace this plugin belongs to.
//! - A null pointer is treated as the empty namespace; constructing a std::string from
//!   nullptr is undefined behavior.
//! - Allocation failures are reported through caughtError() instead of escaping this
//!   noexcept function, which would terminate the process.
void MultiscaleDeformableAttnPlugin::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT
{
    try
    {
        mNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
//! Namespace previously set via setPluginNamespace(); the pointer is owned by this object.
char const* MultiscaleDeformableAttnPlugin::getPluginNamespace() const PLUGIN_NOEXCEPT
{
    char const* const ns = mNamespace.c_str();
    return ns;
}
// Pluginv1 Creator
//! Creator constructor. The plugin takes no creation-time attributes, so the exposed
//! field collection is empty.
MultiscaleDeformableAttnPluginCreator::MultiscaleDeformableAttnPluginCreator()
{
    mPluginAttributes.clear();
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}
//! Name under which this creator's plugin is looked up; matches the plugin's getPluginType().
char const* MultiscaleDeformableAttnPluginCreator::getPluginName() const PLUGIN_NOEXCEPT
{
    char const* const name = DMHA_NAME;
    return name;
}
//! Version of the plugin produced by this creator; matches the plugin's getPluginVersion().
char const* MultiscaleDeformableAttnPluginCreator::getPluginVersion() const PLUGIN_NOEXCEPT
{
    char const* const version = DMHA_VERSION;
    return version;
}
//! Field collection built in the constructor (empty for this plugin); owned by the creator.
nvinfer1::PluginFieldCollection const* MultiscaleDeformableAttnPluginCreator::getFieldNames() PLUGIN_NOEXCEPT
{
    nvinfer1::PluginFieldCollection const* const fields = &mFC;
    return fields;
}
//! Creates a plugin for a network definition.
//! The plugin has no creation-time attributes, so `name` and `fc` are not consulted.
//! Returns nullptr (after reporting through caughtError) if construction throws.
IPluginV2* MultiscaleDeformableAttnPluginCreator::createPlugin(
    char const* name, PluginFieldCollection const* fc) PLUGIN_NOEXCEPT
{
    try
    {
        auto* plugin = new MultiscaleDeformableAttnPlugin();
        // Propagate the creator's namespace, as deserializePlugin() and clone() already do;
        // otherwise a plugin created here would not carry the creator's namespace.
        plugin->setPluginNamespace(getPluginNamespace());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
//! Rebuilds a plugin from its serialized form and tags it with this creator's namespace.
//! `name` is unused. Returns nullptr (after reporting through caughtError) if construction throws.
IPluginV2* MultiscaleDeformableAttnPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) PLUGIN_NOEXCEPT
{
    try
    {
        MultiscaleDeformableAttnPlugin* restored = new MultiscaleDeformableAttnPlugin(serialData, serialLength);
        restored->setPluginNamespace(getPluginNamespace());
        return restored;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
//! Stores the namespace handed to plugins created or deserialized by this creator.
//! - A null pointer is treated as the empty namespace; constructing a std::string from
//!   nullptr is undefined behavior.
//! - Allocation failures are reported through caughtError() instead of escaping this
//!   noexcept function.
void MultiscaleDeformableAttnPluginCreator::setPluginNamespace(char const* pluginNamespace) PLUGIN_NOEXCEPT
{
    try
    {
        mNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
//! Namespace previously set via setPluginNamespace(); the pointer is owned by the creator.
char const* MultiscaleDeformableAttnPluginCreator::getPluginNamespace() const PLUGIN_NOEXCEPT
{
    char const* const ns = mNamespace.c_str();
    return ns;
}
} // namespace plugin
} // namespace nvinfer1