Skip to content

Commit 1a668a4

Browse files
aarch64: matmul: support for per_m scales and bf16 dst in jit int8
1 parent 2f9f145 commit 1a668a4

8 files changed

Lines changed: 236 additions & 85 deletions

File tree

src/common/matmul.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2019 Intel Corporation
3+
* Copyright 2026 Arm Ltd. and affiliates
34
*
45
* Licensed under the Apache License, Version 2.0 (the "License");
56
* you may not use this file except in compliance with the License.
@@ -322,7 +323,7 @@ status_t matmul_attr_check(const matmul_desc_t &desc, const engine_t *engine,
322323
const int mask_src = sc.get_mask(DNNL_ARG_SRC);
323324

324325
VCHECK_MATMUL_UNIMPL(
325-
utils::one_of(mask_src, 0, src_qmask_K,
326+
utils::one_of(mask_src, 0, src_qmask_M, src_qmask_K,
326327
src_qmask_M + src_qmask_K, full_tensor_mask),
327328
VERBOSE_UNSUPPORTED_SCALES_CFG);
328329

src/common/matmul_pd.hpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2019 Intel Corporation
3+
* Copyright 2026 Arm Ltd. and affiliates
34
*
45
* Licensed under the Apache License, Version 2.0 (the "License");
56
* you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@
1819
#define COMMON_MATMUL_PD_HPP
1920

2021
#include <assert.h>
22+
#include <map>
2123

2224
#include "oneapi/dnnl/dnnl.h"
2325

@@ -199,10 +201,20 @@ struct matmul_pd_t : public primitive_desc_t {
199201
virtual bool attr_scales_ok(const std::vector<int> &supported_args
200202
= {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST},
201203
const std::vector<int> &supported_qmodes
202-
= {quantization_mode::static_sazp}) const {
204+
= {quantization_mode::static_sazp},
205+
const std::map<int, std::vector<int>> &extra_masks = {}) const {
203206
const auto &scales = attr()->scales_;
204207
if (scales.has_default_values()) return true;
205208

209+
const auto extra_mask_ok = [&](int arg, int mask) {
210+
const auto it = extra_masks.find(arg);
211+
if (it != extra_masks.end()) {
212+
for (const auto &extra_mask : it->second)
213+
if (mask == extra_mask) return true;
214+
}
215+
return false;
216+
};
217+
206218
bool ok = scales.has_default_values(supported_args);
207219
for (int arg : supported_args) {
208220
if (scales.has_default_values(arg)) { continue; }
@@ -238,22 +250,31 @@ struct matmul_pd_t : public primitive_desc_t {
238250
(mask & wei_qmask_K()), is_decompression);
239251
}
240252
} else if (arg == DNNL_ARG_SRC) {
253+
// Masks supported across all implementations. Implementation
254+
// specific masks can be passed through `extra_masks`.
241255
ok = ok
242-
&& utils::one_of(mask, 0, src_qmask_K(),
243-
src_qmask_M() + src_qmask_K(),
244-
full_tensor_mask());
256+
&& (utils::one_of(mask, 0, src_qmask_K(),
257+
src_qmask_M() + src_qmask_K(),
258+
full_tensor_mask())
259+
|| extra_mask_ok(arg, mask));
245260
ok = ok
246261
&& IMPLICATION((mask & src_qmask_K()),
247262
!scales.get(arg).has_default_groups());
248263
ok = ok
249264
&& IMPLICATION(!scales.get(arg).has_default_groups(),
250265
scales.get_group(arg, 0)
251266
&& K() % scales.get_group(arg, 1) == 0);
267+
ok = ok
268+
&& IMPLICATION(mask == src_qmask_M(),
269+
scales.get(arg).has_default_groups());
252270
} else if (arg == DNNL_ARG_DST) {
271+
// Masks supported across all implementations. Implementation
272+
// specific masks can be passed through `extra_masks`.
253273
ok = ok
254-
&& utils::one_of(mask, 0, dst_qmask_N(),
255-
dst_qmask_M() + dst_qmask_N(),
256-
full_tensor_mask());
274+
&& (utils::one_of(mask, 0, dst_qmask_N(),
275+
dst_qmask_M() + dst_qmask_N(),
276+
full_tensor_mask())
277+
|| extra_mask_ok(arg, mask));
257278
ok = ok
258279
&& IMPLICATION(!scales.get(arg).has_default_groups(),
259280
(M() % scales.get_group(arg, -2)) == 0

src/cpu/aarch64/matmul/jit_int8_kernel_types.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,17 @@ struct brg_int8_t {
6666
int na, nb;
6767
int m_tail, n_tail, k_tail;
6868
int is_m_tail, is_k_tail, is_n_tail, is_zp_cal;
69+
data_type_t dst_dt;
70+
const int acc_dt_sz = sizeof(float);
6971
int dst_dt_sz;
7072
bool is_s8, is_u8_s8;
7173
bool is_bias;
7274
bool with_scales;
75+
bool with_src_scales;
76+
bool with_wei_scales;
7377
bool with_dst_scales;
7478
bool is_oc_scales;
79+
bool is_per_m_scales = false;
7580
jit_int8_broadcast_t zp_type_a = jit_int8_broadcast_t::none;
7681
jit_int8_broadcast_t zp_type_b = jit_int8_broadcast_t::none;
7782
jit_int8_broadcast_t zp_type_c = jit_int8_broadcast_t::none;
@@ -83,8 +88,10 @@ struct brg_int8_t {
8388

8489
struct call_params_t {
8590
const uint8_t *src, *wei;
86-
float *dst;
91+
uint8_t *dst;
8792
const float *bias, *scales, *dst_scales;
93+
const float *src_scales; // optional per-row src logical-M scales
94+
const float *wei_scales; // optional kernel-ready weight scales
8895
dim_t M, K, N;
8996
char *buf_B_ptr_;
9097
int *na, *nb;

0 commit comments

Comments
 (0)