cpu: aarch64: prelu: add jit forward #5299
Add an AArch64 Xbyak JIT implementation for PReLU forward. Support SVE and ASIMD kernels for dense f32 and f16 tensors. Handle full, per-channel, and blocked channel weight broadcasts. Register the JIT implementation before ACL and reference fallbacks. Signed-off-by: Codrut Irimie <codrut.irimie@arm.com>
const Xbyak_aarch64::SReg s_min_ = s2;
const Xbyak_aarch64::SReg s_weights_ = s3;
const Xbyak_aarch64::SReg s_zero_ = s4;
};
Leaking Xbyak internals to lists usually results in longer compilation times, and also forces recompilation of all files when a kernel needs to be updated. Usually, the kernel is kept inside the .cpp and only a forward declaration is provided to store the kernel pointer.
If it is templated by ISA, there's a shim de-templated layer responsible for creation and proper dispatching. You may find an example in x64::jit_uni_layer_norm implementation.
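A minimal sketch of the pattern described above, assuming hypothetical names (`prelu_kernel_base_t`, `prelu_kernel_t`, `isa_t`) that are not taken from this patch or from the x64 implementation. The header part exposes only a de-templated interface, while the ISA-templated kernel and its dispatch stay in the .cpp, so code generator types never reach files that include the header:

```cpp
#include <memory>
#include <string>

// ---- would live in the header (all names here are hypothetical) ----
// De-templated interface: no code generator types are visible to includers.
struct prelu_kernel_base_t {
    virtual ~prelu_kernel_base_t() = default;
    virtual std::string name() const = 0;
    // Shim that picks the ISA-specific kernel at creation time.
    static std::unique_ptr<prelu_kernel_base_t> create(bool has_sve);
};

// ---- would live in the .cpp ----
enum class isa_t { asimd, sve };

// ISA-templated kernel. In a real JIT kernel this class would also derive
// from the code generator; that part is omitted in this sketch.
template <isa_t isa>
struct prelu_kernel_t : prelu_kernel_base_t {
    std::string name() const override {
        return isa == isa_t::sve ? "sve" : "asimd";
    }
};

// Dispatch happens once, behind the de-templated interface.
std::unique_ptr<prelu_kernel_base_t> prelu_kernel_base_t::create(bool has_sve) {
    if (has_sve)
        return std::unique_ptr<prelu_kernel_base_t>(
                new prelu_kernel_t<isa_t::sve>());
    return std::unique_ptr<prelu_kernel_base_t>(
            new prelu_kernel_t<isa_t::asimd>());
}
```

With this split, the primitive only stores a `std::unique_ptr<prelu_kernel_base_t>`, and changing the kernel body recompiles a single translation unit.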
@@ -0,0 +1,562 @@
/*******************************************************************************
* Copyright 2026 Arm Ltd. and affiliates
How similar is this to the x64 version? If any substantial parts were copied, then we should also copy the copyright header.
mov(v_max, P_ALL_ONE / T_m, src);
fmax(v_max, P_ALL_ONE / T_m, v_zero);
mov(v_min, P_ALL_ONE / T_m, src);
fmin(v_min, P_ALL_ONE / T_m, v_zero);
SVE fmin has an immediate form for zero, so we don't need the temporary. Also I don't think we need the mov because src can be the second argument.
https://www.scs.stanford.edu/~zyedidia/arm64/fmin_z_p_zs.html
If there are no missing or slower cases with this new impl, it would be good to remove ACL (#5145)
// The JIT kernel is shared across several PReLU broadcast patterns. The
// primitive descriptor classifies the memory descriptors once, then execute()
// uses this value to choose how to pass pointers and work sizes to the kernel.
enum class bcast {
Is there a reason to avoid the common enum and its associated functions?
oneDNN/src/common/broadcast_strategy.hpp
Lines 31 to 47 in df093fe
That's a good point, I'll use it instead
Description
This patch adds an AArch64 Xbyak JIT implementation for PReLU forward.
The new primitive supports SVE and ASIMD kernels for dense f32 and f16
tensors.
Motivation
The existing AArch64 PReLU path can use ACL, but the operation is simple enough
to benefit from a native oneDNN JIT implementation. PReLU computes:
dst = max(src, 0) + weights * min(src, 0)
Keeping the implementation in oneDNN also allows the primitive
selection logic to use the JIT path directly for supported AArch64 cases while
preserving ACL and reference fallback behavior for unsupported cases.
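For illustration, the formula above can be written as a scalar reference. This is only a sketch: the helper names (`prelu_scalar`, `prelu_nchw_per_channel`) are hypothetical, and the dense NCHW layout with one weight per channel is an assumption chosen to match the per-channel broadcast case. It is not the JIT kernel from this patch.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Scalar PReLU: keep the positive part, scale the negative part by the weight.
inline float prelu_scalar(float src, float weight) {
    return std::max(src, 0.0f) + weight * std::min(src, 0.0f);
}

// Hypothetical reference helper: per-channel PReLU over a dense NCHW tensor.
// `weights` holds one value per channel (size C); HW is the spatial size H * W.
inline std::vector<float> prelu_nchw_per_channel(const std::vector<float> &src,
        const std::vector<float> &weights, std::size_t N, std::size_t C,
        std::size_t HW) {
    std::vector<float> dst(src.size());
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t c = 0; c < C; ++c)
            for (std::size_t i = 0; i < HW; ++i) {
                // Offset of element (n, c, i) in a dense NCHW buffer.
                const std::size_t off = (n * C + c) * HW + i;
                dst[off] = prelu_scalar(src[off], weights[c]);
            }
    return dst;
}
```

A scalar loop like this is useful mainly as a correctness oracle when comparing against a vectorized kernel.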
This patch keeps the initial scope intentionally small:
Performance
Benchmarked with benchdnn PReLU tests, mode=P. The comparison is between the
new AArch64 JIT implementation and the existing ACL implementation.
The table below reports the average benchdnn avg_time across the four cases
for each data type and thread count.
Across all 56 measured cases:
The best measured case was f32 NCHW per-channel PReLU at 16 threads:
+----------------+--------------+
| Implementation | Avg time |
+----------------+--------------+
| JIT | 0.345023 ms |
| ACL | 10.2437 ms |
+----------------+--------------+
This is a 29.69x mean-time speedup over ACL for:
At 64 threads the mean-time speedup is lower, likely due to thread overhead and
memory bandwidth saturation, but JIT remained faster than ACL for every measured
case.
Checklist
General
Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
Performance improvements