cpu: aarch64: prelu: add jit forward #5299

Open
CodrutIrimieARM wants to merge 1 commit into uxlfoundation:main from CodrutIrimieARM:main

@CodrutIrimieARM

Description

This patch adds an AArch64 Xbyak JIT implementation for PReLU forward.

The new primitive supports SVE and ASIMD kernels for dense f32 and f16
tensors.

Motivation

The existing AArch64 PReLU path can use ACL, but the operation is simple enough
to benefit from a native oneDNN JIT implementation. PReLU computes:

dst = max(src, 0) + weights * min(src, 0)
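
As a point of reference for the formula above, here is a minimal scalar sketch in C++. It is not the JIT kernel from this patch: `prelu_ref` and `prelu_nchw_per_channel` are hypothetical names, and the dense NCHW layout with one weight per channel is an assumption chosen to illustrate the per-channel broadcast case.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Scalar reference for one element: dst = max(src, 0) + weights * min(src, 0).
// Hypothetical helper, not a oneDNN function.
inline float prelu_ref(float src, float weight) {
    return std::max(src, 0.0f) + weight * std::min(src, 0.0f);
}

// Per-channel broadcast over a dense NCHW tensor (assumed layout):
// weights holds one value per channel and is reused for every batch
// index and every spatial position.
inline std::vector<float> prelu_nchw_per_channel(const std::vector<float> &src,
        const std::vector<float> &weights, std::size_t n, std::size_t c,
        std::size_t spatial) {
    std::vector<float> dst(src.size());
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t ch = 0; ch < c; ++ch)
            for (std::size_t s = 0; s < spatial; ++s) {
                const std::size_t off = (i * c + ch) * spatial + s;
                dst[off] = prelu_ref(src[off], weights[ch]);
            }
    return dst;
}
```

Positive inputs pass through unchanged, while negative inputs are scaled by the weight that matches their channel.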

Keeping the implementation in oneDNN also allows the primitive
selection logic to use the JIT path directly for supported AArch64 cases while
preserving ACL and reference fallback behavior for unsupported cases.

This patch keeps the initial scope intentionally small:

  • forward inference/training data path only
  • f32 and f16 tensors
  • default attributes
  • src and dst with matching dense memory descriptors
  • full, per-channel, and blocked channel broadcasts
  • SVE and ASIMD code generation through Xbyak_aarch64

Performance

Benchmarked with benchdnn PReLU tests, mode=P. The comparison is between the
new AArch64 JIT implementation and the existing ACL implementation.

The table below reports the average benchdnn avg_time across the four cases
for each data type and thread count.

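+---------+-----------+--------------------+--------------------+-------------------+
| Threads | Data type | ACL mean time (ms) | JIT mean time (ms) | Speedup (ACL/JIT) |
+---------+-----------+--------------------+--------------------+-------------------+
|       1 | f16       |          71.417575 |           7.488973 |             9.54x |
|       1 | f32       |          99.224925 |           4.846303 |            20.47x |
|       2 | f16       |          36.859750 |           3.753510 |             9.82x |
|       2 | f32       |          50.701650 |           2.436695 |            20.81x |
|       4 | f16       |          19.442625 |           1.882810 |            10.33x |
|       4 | f32       |          26.873900 |           1.233500 |            21.79x |
|       8 | f16       |          10.293645 |           0.945506 |            10.89x |
|       8 | f32       |          14.090725 |           0.635043 |            22.19x |
|      16 | f16       |           6.349348 |           0.475104 |            13.36x |
|      16 | f32       |           8.535005 |           0.439290 |            19.43x |
|      32 | f16       |           3.584250 |           0.240958 |            14.87x |
|      32 | f32       |           5.022210 |           0.316064 |            15.89x |
|      64 | f16       |           2.725655 |           0.688183 |             3.96x |
|      64 | f32       |           3.315585 |           0.687887 |             4.82x |
+---------+-----------+--------------------+--------------------+-------------------+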

Across all 56 measured cases:

  • Average JIT speedup by mean time: 14.32x
  • Average JIT speedup by min time: 16.01x
  • No benchmark failures or skips

The best measured case was f32 NCHW per-channel PReLU at 16 threads:

+----------------+-------------+
| Implementation | Avg time    |
+----------------+-------------+
| JIT            | 0.345023 ms |
| ACL            | 10.2437 ms  |
+----------------+-------------+

This is a 29.69x mean-time speedup over ACL for:

--mode=P --prelu --dir=FWD_D --sdt=f32:f32 --stag=abx:abx 32x128x53x53:1x128x1x1
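
The arithmetic behind this case can be checked with a short C++ sketch. The speedup is the ACL time divided by the JIT time. The bandwidth estimate is only a rough illustration under stated assumptions: it counts one read of src and one write of dst at 4 bytes per f32 element and ignores the weights tensor and any cache effects. The function names are hypothetical.

```cpp
#include <cstdint>

// Element count of the 32x128x53x53 source tensor in the benchdnn problem.
constexpr std::int64_t num_elements = 32LL * 128 * 53 * 53;

// Speedup as reported in this description: ACL time divided by JIT time.
constexpr double speedup(double acl_ms, double jit_ms) {
    return acl_ms / jit_ms;
}

// Rough effective bandwidth in GB/s.
// Assumption: src is read once and dst is written once, 4 bytes per f32
// element; weights traffic is ignored.
constexpr double effective_gb_per_s(std::int64_t elems, double time_ms) {
    return (2.0 * 4.0 * static_cast<double>(elems)) / (time_ms * 1e-3) / 1e9;
}
```

Calling `speedup(10.2437, 0.345023)` reproduces the ratio quoted above, and `effective_gb_per_s(num_elements, 0.345023)` gives an estimate of the memory traffic the JIT kernel sustains in this case under those assumptions.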

At 64 threads the mean-time speedup drops to 3.96x (f16) and 4.82x (f32), and
the JIT mean time is higher than at 32 threads. This is likely due to threading
overhead and memory bandwidth saturation, but JIT remained faster than ACL in
every measured case.

Checklist

General

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

Performance improvements

  • Have you submitted performance data that demonstrates performance improvements?

Add an AArch64 Xbyak JIT implementation for PReLU forward.

Support SVE and ASIMD kernels for dense f32 and f16 tensors.
Handle full, per-channel, and blocked channel weight broadcasts.
Register the JIT implementation before ACL and reference fallbacks.

Signed-off-by: Codrut Irimie <codrut.irimie@arm.com>
@CodrutIrimieARM CodrutIrimieARM requested review from a team as code owners June 11, 2026 15:50
@github-actions github-actions Bot added platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64 component:common labels Jun 11, 2026
const Xbyak_aarch64::SReg s_min_ = s2;
const Xbyak_aarch64::SReg s_weights_ = s3;
const Xbyak_aarch64::SReg s_zero_ = s4;
};

Contributor


Leaking Xbyak internals to the implementation lists usually results in longer compilation times, and it forces every file that includes this header to be recompiled whenever a kernel is updated. Usually, the kernel definition is kept inside the .cpp, and the header provides only a forward declaration so that it can store the kernel pointer.

If it is templated by ISA, there's a shim de-templated layer responsible for creation and proper dispatching. You may find an example in x64::jit_uni_layer_norm implementation.
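
One way to read this suggestion is sketched below in plain C++. It is an illustration only, not the layout used by x64::jit_uni_layer_norm or by this patch: every name (`prelu_fwd_t`, `jit_prelu_kernel_base_t`, `create_kernel`, `isa_t`) is hypothetical, and the kernel body is a scalar stand-in for generated code. The header part holds only a forward declaration and a pointer, while the ISA-templated kernel and a non-templated factory live in the source part.

```cpp
#include <algorithm>
#include <memory>

// ---- header part (hypothetical prelu.hpp): no Xbyak types visible ----
struct jit_prelu_kernel_base_t; // forward declaration only

class prelu_fwd_t {
public:
    explicit prelu_fwd_t(bool has_sve);
    ~prelu_fwd_t(); // defined where the kernel type is complete
    float run(float src, float weight) const;

private:
    std::unique_ptr<jit_prelu_kernel_base_t> kernel_;
};

// ---- source part (hypothetical prelu.cpp): JIT details stay here ----
struct jit_prelu_kernel_base_t {
    virtual ~jit_prelu_kernel_base_t() = default;
    virtual float operator()(float src, float weight) const = 0;
};

enum class isa_t { asimd, sve };

template <isa_t isa>
struct jit_prelu_kernel_t : jit_prelu_kernel_base_t {
    float operator()(float src, float weight) const override {
        // Scalar stand-in for the code a real kernel would generate.
        return std::max(src, 0.0f) + weight * std::min(src, 0.0f);
    }
};

// De-templated shim: the only place that picks the ISA specialization.
inline std::unique_ptr<jit_prelu_kernel_base_t> create_kernel(bool has_sve) {
    if (has_sve)
        return std::unique_ptr<jit_prelu_kernel_base_t>(
                new jit_prelu_kernel_t<isa_t::sve>());
    return std::unique_ptr<jit_prelu_kernel_base_t>(
            new jit_prelu_kernel_t<isa_t::asimd>());
}

prelu_fwd_t::prelu_fwd_t(bool has_sve) : kernel_(create_kernel(has_sve)) {}
prelu_fwd_t::~prelu_fwd_t() = default;
float prelu_fwd_t::run(float src, float weight) const {
    return (*kernel_)(src, weight);
}
```

With this split, a change to the kernel template touches only the source file, because files that include the header see nothing beyond the forward declaration.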

@@ -0,0 +1,562 @@
/*******************************************************************************
* Copyright 2026 Arm Ltd. and affiliates

Contributor


How similar is this to the x64 version? If any substantial parts were copied, then we should also copy the copyright header.

mov(v_max, P_ALL_ONE / T_m, src);
fmax(v_max, P_ALL_ONE / T_m, v_zero);
mov(v_min, P_ALL_ONE / T_m, src);
fmin(v_min, P_ALL_ONE / T_m, v_zero);

Contributor


SVE fmin has an immediate form for zero, so we don't need the temporary. Also I don't think we need the mov because src can be the second argument.

https://www.scs.stanford.edu/~zyedidia/arm64/fmin_z_p_zs.html

@jondea

jondea commented Jun 12, 2026

Contributor

If there are no missing or slower cases with this new impl, it would be good to remove ACL (#5145)

// The JIT kernel is shared across several PReLU broadcast patterns. The
// primitive descriptor classifies the memory descriptors once, then execute()
// uses this value to choose how to pass pointers and work sizes to the kernel.
enum class bcast {

@Sqvid Sqvid Jun 12, 2026

Contributor


Is there a reason to avoid the common enum and its associated functions?

enum class broadcasting_strategy_t {
    // [n, c, d, h, w]
    scalar, // [1, 1, 1, 1, 1] // Channel_shared
    per_oc, // [1, c, 1, 1, 1] // Channel-wise
    per_oc_spatial, // [1, c, 1, 1, 1] specific case for binary kernel nchw format
    per_mb, // [n, 1, 1, 1, 1] // broadcast per batch
    per_oc_d, // [a, b, c, d] -> [1, b, c, 1]; [n, g, oc/g, sp] --> [1, g, oc/g, 1] specific case for ncsp matmul reduction.
    per_mb_spatial, // [n, 1, d, h, w] // Broadcast only channel
    per_mb_w, // [n, 1, 1, 1, w] // Broadcast per batch and width
    per_w, // [1, 1, 1, 1, w] // Broadcast per width
    per_hw, // [1, 1, h, w] // Broadcast per height and width
    shared_axes, // [n, 1, d, h, 1] // General case broadcast (any combination)
    batch, // [1, c, d, h, w] // Broadcast only batch
    spatial, // [n, c, 1, 1, 1] // Broadcast spatial dimensions
    no_broadcast, // [n, c, d, h, w]
    unsupported
};

Author


That's a good point, I'll use it instead
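
To illustrate the kind of classification discussed in this thread, here is a reduced C++ sketch that maps a weights shape onto three of the strategies quoted above. It is not oneDNN's helper for this enum: `bcast_kind_t` and `classify` are hypothetical, shapes are assumed to be in logical [n, c, spatial...] order, and every pattern outside the three handled ones is reported as unsupported.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reduced subset of the strategies in the quoted enum.
enum class bcast_kind_t { scalar, per_oc, no_broadcast, unsupported };

// Classify weights dims against src dims, both in [n, c, spatial...] order.
inline bcast_kind_t classify(
        const std::vector<long> &src, const std::vector<long> &wei) {
    if (src.size() != wei.size() || src.size() < 2)
        return bcast_kind_t::unsupported;

    bool all_one = true; // every weights dim is 1
    bool all_equal = true; // weights shape equals src shape
    bool only_c = true; // only the channel dim (index 1) matches src
    for (std::size_t d = 0; d < src.size(); ++d) {
        // Each weights dim must either broadcast (1) or match src exactly.
        if (wei[d] != 1 && wei[d] != src[d]) return bcast_kind_t::unsupported;
        all_one = all_one && wei[d] == 1;
        all_equal = all_equal && wei[d] == src[d];
        only_c = only_c && (d == 1 ? wei[d] == src[d] : wei[d] == 1);
    }

    if (all_equal) return bcast_kind_t::no_broadcast;
    if (all_one) return bcast_kind_t::scalar;
    if (only_c) return bcast_kind_t::per_oc;
    return bcast_kind_t::unsupported; // other patterns are out of scope here
}
```

For the benchmarked problem 32x128x53x53:1x128x1x1, this sketch returns `per_oc`, which corresponds to the channel-wise entry in the quoted enum.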
