|
1 | 1 | /******************************************************************************* |
2 | 2 | * Copyright 2019 Intel Corporation |
| 3 | +* Copyright 2026 Arm Ltd. and affiliates |
3 | 4 | * |
4 | 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 6 | * you may not use this file except in compliance with the License. |
|
18 | 19 | #define COMMON_MATMUL_PD_HPP |
19 | 20 |
|
20 | 21 | #include <assert.h> |
| 22 | +#include <map> |
21 | 23 |
|
22 | 24 | #include "oneapi/dnnl/dnnl.h" |
23 | 25 |
|
@@ -199,10 +201,20 @@ struct matmul_pd_t : public primitive_desc_t { |
199 | 201 | virtual bool attr_scales_ok(const std::vector<int> &supported_args |
200 | 202 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}, |
201 | 203 | const std::vector<int> &supported_qmodes |
202 | | - = {quantization_mode::static_sazp}) const { |
| 204 | + = {quantization_mode::static_sazp}, |
| 205 | + const std::map<int, std::vector<int>> &extra_masks = {}) const { |
203 | 206 | const auto &scales = attr()->scales_; |
204 | 207 | if (scales.has_default_values()) return true; |
205 | 208 |
|
| 209 | + const auto extra_mask_ok = [&](int arg, int mask) { |
| 210 | + const auto it = extra_masks.find(arg); |
| 211 | + if (it != extra_masks.end()) { |
| 212 | + for (const auto &extra_mask : it->second) |
| 213 | + if (mask == extra_mask) return true; |
| 214 | + } |
| 215 | + return false; |
| 216 | + }; |
| 217 | + |
206 | 218 | bool ok = scales.has_default_values(supported_args); |
207 | 219 | for (int arg : supported_args) { |
208 | 220 | if (scales.has_default_values(arg)) { continue; } |
@@ -238,22 +250,31 @@ struct matmul_pd_t : public primitive_desc_t { |
238 | 250 | (mask & wei_qmask_K()), is_decompression); |
239 | 251 | } |
240 | 252 | } else if (arg == DNNL_ARG_SRC) { |
| 253 | + // Masks supported across all implementations. Implementation |
| 254 | + // specific masks can be passed through `extra_masks`. |
241 | 255 | ok = ok |
242 | | - && utils::one_of(mask, 0, src_qmask_K(), |
243 | | - src_qmask_M() + src_qmask_K(), |
244 | | - full_tensor_mask()); |
| 256 | + && (utils::one_of(mask, 0, src_qmask_K(), |
| 257 | + src_qmask_M() + src_qmask_K(), |
| 258 | + full_tensor_mask()) |
| 259 | + || extra_mask_ok(arg, mask)); |
245 | 260 | ok = ok |
246 | 261 | && IMPLICATION((mask & src_qmask_K()), |
247 | 262 | !scales.get(arg).has_default_groups()); |
248 | 263 | ok = ok |
249 | 264 | && IMPLICATION(!scales.get(arg).has_default_groups(), |
250 | 265 | scales.get_group(arg, 0) |
251 | 266 | && K() % scales.get_group(arg, 1) == 0); |
| 267 | + ok = ok |
| 268 | + && IMPLICATION(mask == src_qmask_M(), |
| 269 | + scales.get(arg).has_default_groups()); |
252 | 270 | } else if (arg == DNNL_ARG_DST) { |
| 271 | + // Masks supported across all implementations. Implementation |
| 272 | + // specific masks can be passed through `extra_masks`. |
253 | 273 | ok = ok |
254 | | - && utils::one_of(mask, 0, dst_qmask_N(), |
255 | | - dst_qmask_M() + dst_qmask_N(), |
256 | | - full_tensor_mask()); |
| 274 | + && (utils::one_of(mask, 0, dst_qmask_N(), |
| 275 | + dst_qmask_M() + dst_qmask_N(), |
| 276 | + full_tensor_mask()) |
| 277 | + || extra_mask_ok(arg, mask)); |
257 | 278 | ok = ok |
258 | 279 | && IMPLICATION(!scales.get(arg).has_default_groups(), |
259 | 280 | (M() % scales.get_group(arg, -2)) == 0 |
|
0 commit comments