Skip to content

Commit 21108fd

Browse files
authored
optimize phi::fusion::AttnMatMul (#78763)
* optimize phi::fusion::AttnMatMul
* fix
1 parent 9f33b3d commit 21108fd

21 files changed

Lines changed: 245 additions & 258 deletions

paddle/phi/kernels/funcs/fast_ln_v1.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_v1_fwd_kernel(
6565
#pragma unroll
6666
for (int it = 0, col = c; it < LDGS; it++) {
6767
if (col < cols) {
68-
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
69-
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
68+
Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
69+
Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
7070
} else {
7171
gamma[it] = Vec_scale{};
7272
beta[it] = Vec_scale{};
@@ -80,7 +80,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_v1_fwd_kernel(
8080
#pragma unroll
8181
for (int it = 0, col = c; it < LDGS; it++) {
8282
if (col < cols) {
83-
phi::Load<T, VecSize>(
83+
Load<T, VecSize>(
8484
x_ptr + static_cast<int64_t>(row) * ELTS_PER_ROW + col * VecSize,
8585
&x[it]);
8686
} else {

paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,16 +280,16 @@ __global__ void DequantKernel(T* output,
280280
AlignedVector<T, VecSize> out_vec;
281281

282282
for (; idx < numel; idx += stride) {
283-
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
284-
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
283+
Load<int32_t, VecSize>(input + idx, &in_vec);
284+
Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
285285

286286
#pragma unroll
287287
for (int i = 0; i < VecSize; ++i) {
288288
out_vec[i] =
289289
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
290290
}
291291

292-
phi::Store<T, VecSize>(out_vec, output + idx);
292+
Store<T, VecSize>(out_vec, output + idx);
293293
}
294294
}
295295

paddle/phi/kernels/fusion/gpu/fused_attention_grad_kernel.cu

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,13 @@ void FusedAttentionGradKernel(
167167

168168
const bool is_upscale_in_train =
169169
(dropout_implementation == "upscale_in_train");
170-
phi::fusion::DropoutParam dropout_param2(dropout_fix_seed,
171-
0,
172-
is_test,
173-
is_upscale_in_train,
174-
dropout_rate,
175-
nullptr,
176-
dropout_seed);
170+
fusion::DropoutParam dropout_param2(dropout_fix_seed,
171+
0,
172+
is_test,
173+
is_upscale_in_train,
174+
dropout_rate,
175+
nullptr,
176+
dropout_seed);
177177
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
178178

179179
bool is_upscale_in_train_1 =
@@ -324,31 +324,31 @@ void FusedAttentionGradKernel(
324324
bool transB = transpose_qkv_wb ? false : true;
325325
bool compute_qkv_bias = qkv_bias_p ? true : false;
326326
auto layer_norm_compute =
327-
phi::fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
328-
auto qkv_compute = phi::fusion::AttnMatMul<T>(dev_ctx,
329-
transA,
330-
transB,
331-
bsz_seq,
332-
output_size,
333-
input_size,
334-
compute_qkv_bias);
335-
phi::fusion::AttnDropoutParam attn_dropout_param(is_test,
336-
attn_dropout_implementation,
337-
attn_dropout_rate,
338-
is_upscale_in_train_1,
339-
attn_dropout_fix_seed,
340-
attn_dropout_seed,
341-
seed_1);
342-
auto fmha_ref_compute = phi::fusion::FMHARef<T>(
327+
fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
328+
auto qkv_compute = fusion::AttnMatMul<T>(dev_ctx,
329+
transA,
330+
transB,
331+
bsz_seq,
332+
output_size,
333+
input_size,
334+
compute_qkv_bias);
335+
fusion::AttnDropoutParam attn_dropout_param(is_test,
336+
attn_dropout_implementation,
337+
attn_dropout_rate,
338+
is_upscale_in_train_1,
339+
attn_dropout_fix_seed,
340+
attn_dropout_seed,
341+
seed_1);
342+
auto fmha_ref_compute = fusion::FMHARef<T>(
343343
dev_ctx, batch_size, max_seq_len, num_head, dim_head, attn_dropout_param);
344344
output_size = hidden_size;
345345
transA = false;
346346
transB = false;
347347
bool compute_bias = false;
348348
// (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
349-
auto out_linear_compute = phi::fusion::AttnMatMul<T>(
349+
auto out_linear_compute = fusion::AttnMatMul<T>(
350350
dev_ctx, transA, transB, bsz_seq, input_size, output_size, compute_bias);
351-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
351+
fusion::FusedDropoutLayerNormHelper<T, uint8_t>
352352
fused_dropout_layernorm_helper(
353353
dev_ctx, bsz_seq, dim_embed, dropout_param2, ln_epsilon);
354354

paddle/phi/kernels/fusion/gpu/fused_attention_kernel.cu

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,13 @@ void FusedAttentionKernel(const Context &dev_ctx,
139139

140140
const bool is_upscale_in_train =
141141
(dropout_implementation == "upscale_in_train");
142-
phi::fusion::DropoutParam dropout_param2(dropout_fix_seed,
143-
0,
144-
is_test,
145-
is_upscale_in_train,
146-
dropout_rate,
147-
nullptr,
148-
dropout_seed);
142+
fusion::DropoutParam dropout_param2(dropout_fix_seed,
143+
0,
144+
is_test,
145+
is_upscale_in_train,
146+
dropout_rate,
147+
nullptr,
148+
dropout_seed);
149149

150150
const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);
151151

@@ -240,25 +240,25 @@ void FusedAttentionKernel(const Context &dev_ctx,
240240
int input_size = dim_embed;
241241

242242
auto layer_norm_compute =
243-
phi::fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
243+
fusion::AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
244244

245245
bool compute_bias = true;
246246
if (qkv_bias_p == nullptr) {
247247
compute_bias = false;
248248
}
249249
// (transA, transB, compute_bias) = (false, true, true)
250250
bool transB = transpose_qkv_wb ? false : true;
251-
auto qkv_compute = phi::fusion::AttnMatMul<T>(
251+
auto qkv_compute = fusion::AttnMatMul<T>(
252252
dev_ctx, false, transB, bsz_seq, output_size, input_size, compute_bias);
253253

254-
phi::fusion::AttnDropoutParam attn_dropout_param(is_test,
255-
attn_dropout_implementation,
256-
attn_dropout_rate,
257-
is_upscale_in_train_1,
258-
attn_dropout_fix_seed,
259-
attn_dropout_seed,
260-
seed_1);
261-
auto fmha_ref_compute = phi::fusion::FMHARef<T>(
254+
fusion::AttnDropoutParam attn_dropout_param(is_test,
255+
attn_dropout_implementation,
256+
attn_dropout_rate,
257+
is_upscale_in_train_1,
258+
attn_dropout_fix_seed,
259+
attn_dropout_seed,
260+
seed_1);
261+
auto fmha_ref_compute = fusion::FMHARef<T>(
262262
dev_ctx, batch_size, max_seq_len, num_head, dim_head, attn_dropout_param);
263263

264264
output_size = hidden_size;
@@ -268,9 +268,9 @@ void FusedAttentionKernel(const Context &dev_ctx,
268268
// which is actually the input size. While the input size is hidden size,
269269
// which is actually the output size. So for out linear, switch the
270270
// input size and output size.
271-
auto out_linear_compute = phi::fusion::AttnMatMul<T>(
271+
auto out_linear_compute = fusion::AttnMatMul<T>(
272272
dev_ctx, false, false, bsz_seq, input_size, output_size, false);
273-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
273+
fusion::FusedDropoutLayerNormHelper<T, uint8_t>
274274
fused_dropout_layernorm_helper(
275275
dev_ctx, bsz_seq, dim_embed, dropout_param2, ln_epsilon);
276276

paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ __global__ void ActFFNGlu(const T *bias,
5050
load_func.template load<VecSize>(&src_vec2, index + hid_dim);
5151

5252
if (bias) {
53-
phi::Load<T, VecSize>(&bias[idx], &bias_vec1);
54-
phi::Load<T, VecSize>(&bias[idx + hid_dim], &bias_vec2);
53+
Load<T, VecSize>(&bias[idx], &bias_vec1);
54+
Load<T, VecSize>(&bias[idx + hid_dim], &bias_vec2);
5555
}
5656
#pragma unroll
5757
for (int j = 0; j < VecSize; j++) {
@@ -134,7 +134,7 @@ __global__ void BiasAct(const T *bias,
134134
int64_t linear_idx = row_idx * cols + col_idx;
135135
load_func.template load<VecSize>(&src_vec, linear_idx);
136136
if (bias) {
137-
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
137+
Load<T, VecSize>(&bias[col_idx], &bias_vec);
138138
}
139139
#pragma unroll
140140
for (int j = 0; j < VecSize; j++) {

paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,15 @@ void FusedBiasDropoutResidualLnGradKernel(
102102
bsz_seq *= input_x_dims[i];
103103
}
104104
int64_t dim_embed = input_x_dims[input_x_dims.size() - 1];
105-
phi::fusion::DropoutParam dropout_param(
105+
fusion::DropoutParam dropout_param(
106106
dropout_fix_seed,
107107
0,
108108
is_test,
109109
dropout_implementation == "upscale_in_train",
110110
dropout_rate,
111111
nullptr,
112112
dropout_seed);
113-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
113+
fusion::FusedDropoutLayerNormHelper<T, uint8_t>
114114
fused_dropout_layernorm_helper(
115115
dev_ctx, bsz_seq, dim_embed, dropout_param, ln_epsilon);
116116
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(

paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ void FusedBiasDropoutResidualLnKernel(const Context& dev_ctx,
6868
bsz_seq *= input_x_dims[i];
6969
}
7070
int dim_embed = input_x_dims[input_x_dims.size() - 1];
71-
phi::fusion::DropoutParam dropout_param(
71+
fusion::DropoutParam dropout_param(
7272
dropout_fix_seed,
7373
0,
7474
is_test,
7575
dropout_implementation == "upscale_in_train",
7676
dropout_rate,
7777
nullptr,
7878
dropout_seed);
79-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
79+
fusion::FusedDropoutLayerNormHelper<T, uint8_t>
8080
fused_dropout_layernorm_helper(
8181
dev_ctx, bsz_seq, dim_embed, dropout_param, ln_epsilon);
8282
// output = layernorm(residual + dropout(input + bias))

paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,11 @@ __global__ void FusedActBias(Functor act,
162162
idx < elem_cnt;
163163
idx += step) {
164164
const int32_t col_idx = idx % cols;
165-
phi::Load<InType, VecSize>(&src[idx], &src_vec);
166-
phi::Load<float, VecSize>(&dequant_out_scale_data[col_idx],
167-
&dequant_out_scale_vec);
165+
Load<InType, VecSize>(&src[idx], &src_vec);
166+
Load<float, VecSize>(&dequant_out_scale_data[col_idx],
167+
&dequant_out_scale_vec);
168168
if (bias) {
169-
phi::Load<T, VecSize>(&bias[col_idx], &bias_vec);
169+
Load<T, VecSize>(&bias[col_idx], &bias_vec);
170170
}
171171
#pragma unroll
172172
for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) {
@@ -194,7 +194,7 @@ __global__ void FusedActBias(Functor act,
194194
}
195195
}
196196
}
197-
phi::Store<OutType, VecSize>(out_vec, &dst[idx]);
197+
Store<OutType, VecSize>(out_vec, &dst[idx]);
198198
}
199199
}
200200

@@ -322,17 +322,17 @@ __global__ void FusedDropoutActGrad(Functor act_grad,
322322
LoadT src_vec;
323323
MaskLoadT mask_vec;
324324

325-
phi::Load<T, VecSize>(&dout[i], &dout_vec);
326-
phi::Load<MaskType, VecSize>(&mask[i], &mask_vec);
327-
phi::Load<T, VecSize>(&src[i], &src_vec);
325+
Load<T, VecSize>(&dout[i], &dout_vec);
326+
Load<MaskType, VecSize>(&mask[i], &mask_vec);
327+
Load<T, VecSize>(&src[i], &src_vec);
328328

329329
StoreT dx_vec;
330330
#pragma unroll
331331
for (int ii = 0; ii < VecSize; ii++) {
332332
T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
333333
dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
334334
}
335-
phi::Store<T, VecSize>(dx_vec, &dx[i]);
335+
Store<T, VecSize>(dx_vec, &dx[i]);
336336
}
337337
}
338338

@@ -376,10 +376,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
376376
LoadT bias_vec;
377377
MaskLoadT mask_vec;
378378

379-
phi::Load<T, VecSize>(&dout[index], &dout_vec);
380-
phi::Load<T, VecSize>(&src[index], &src_vec);
381-
phi::Load<MaskType, VecSize>(&mask[index], &mask_vec);
382-
phi::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
379+
Load<T, VecSize>(&dout[index], &dout_vec);
380+
Load<T, VecSize>(&src[index], &src_vec);
381+
Load<MaskType, VecSize>(&mask[index], &mask_vec);
382+
Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
383383

384384
StoreT dx_vec;
385385
#pragma unroll
@@ -390,7 +390,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void FusedDropoutActBiasGrad(
390390
dx_vec[i] = val;
391391
tmp_sum[i] += val;
392392
}
393-
phi::Store<T, VecSize>(dx_vec, &dx[index]);
393+
Store<T, VecSize>(dx_vec, &dx[index]);
394394
}
395395
}
396396

paddle/phi/kernels/fusion/gpu/fused_feedforward_grad_kernel.cu

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,19 @@ void FFNGrad(const GPUContext& dev_ctx,
7777
const int bsz_seq,
7878
const int d_model,
7979
const int dim_feedforward,
80-
const phi::fusion::DropoutParam& dropout_param1,
81-
const phi::fusion::DropoutParam& dropout_param2,
80+
const fusion::DropoutParam& dropout_param1,
81+
const fusion::DropoutParam& dropout_param2,
8282
const std::string& act_method,
8383
const bool pre_layer_norm,
8484
const float epsilon1,
8585
const float epsilon2,
8686
const bool add_residual,
8787
const int ring_id) {
88-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
88+
fusion::FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
8989
bsz_seq, d_model, epsilon1);
90-
phi::fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
90+
fusion::FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
9191
dev_ctx, bsz_seq, dim_feedforward, dropout_param1);
92-
phi::fusion::FusedDropoutLayerNormHelper<T, uint8_t>
92+
fusion::FusedDropoutLayerNormHelper<T, uint8_t>
9393
fused_dropout_layernorm_helper(
9494
dev_ctx, bsz_seq, d_model, dropout_param2, epsilon2);
9595

@@ -283,20 +283,20 @@ void FusedFeedForwardGradKernel(const Context& dev_ctx,
283283
bool is_upscale_in_train1 = dropout1_implementation == "upscale_in_train";
284284
bool is_upscale_in_train2 = dropout2_implementation == "upscale_in_train";
285285

286-
phi::fusion::DropoutParam dropout_param1(dropout1_fix_seed,
287-
0,
288-
is_test,
289-
is_upscale_in_train1,
290-
dropout1_prob,
291-
nullptr,
292-
dropout1_seed_val);
293-
phi::fusion::DropoutParam dropout_param2(dropout2_fix_seed,
294-
0,
295-
is_test,
296-
is_upscale_in_train2,
297-
dropout2_prob,
298-
nullptr,
299-
dropout2_seed_val);
286+
fusion::DropoutParam dropout_param1(dropout1_fix_seed,
287+
0,
288+
is_test,
289+
is_upscale_in_train1,
290+
dropout1_prob,
291+
nullptr,
292+
dropout1_seed_val);
293+
fusion::DropoutParam dropout_param2(dropout2_fix_seed,
294+
0,
295+
is_test,
296+
is_upscale_in_train2,
297+
dropout2_prob,
298+
nullptr,
299+
dropout2_seed_val);
300300

301301
dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
302302
if (d_ln1_scale) {

0 commit comments

Comments
 (0)