Skip to content

feat: add Metal GPU backend for Apple Silicon#7215

Open
discordwell wants to merge 11 commits intolightgbm-org:masterfrom
discordwell:feature/metal-backend
Open

feat: add Metal GPU backend for Apple Silicon#7215
discordwell wants to merge 11 commits intolightgbm-org:masterfrom
discordwell:feature/metal-backend

Conversation

@discordwell
Copy link
Copy Markdown

Summary

Adds a native Metal compute shader backend for GPU-accelerated histogram construction on Apple Silicon Macs, addressing the gap left by the deprecated OpenCL backend which crashes on Apple Silicon (#6189).

  • New device_type="metal" option
  • Three Metal compute kernels ported from OpenCL (histogram 16/64/256 bins)
  • Multi-workgroup support with CPU-side sub-histogram reduction
  • Pre-compiled metallib at build time; runtime .metal compilation as fallback
  • 21-test suite, documentation, build-python.sh integration

Design

Follows the OpenCL GPUTreeLearner architecture: MetalTreeLearner extends SerialTreeLearner, accelerating only histogram construction on GPU. Split finding and tree construction remain on CPU. Apple Silicon's unified memory eliminates the pinned-buffer / async-DMA complexity of the OpenCL path.

Multi-workgroup histogram accumulation writes sub-histograms to device memory; a CPU-side reduction merges them with the necessary separated-to-interleaved format conversion and feature-order reversal (matching within_kernel_reduction). This avoids Metal's lack of reliable cross-threadgroup synchronization within a single dispatch.

Benchmarks (Apple M4 Max)

Tested against single-threaded CPU with num_leaves=63, max_bin=63:

Dataset Rows Features Metal vs CPU
Higgs 10M 28 2.94x faster
Epsilon-like 400K 2,000 11.3x faster
Bosch-like 1.2M 968 8.6x faster
Synthetic 100K 50 0.54x (dispatch overhead dominates)

Performance characteristics match the existing OpenCL/CUDA backends: GPU acceleration requires large or wide datasets (≥300K rows) where histogram computation dominates over per-dispatch overhead.

Test results

21 passed in 2.5s (Apple M4 Max)
  • Binary classification, regression, multiclass
  • All 3 histogram kernel variants (max_bin 15/63/255)
  • Small + large dataset multi-workgroup paths
  • Bagging, constant hessian, gpu_use_dp override
  • Linear tree, refit, training data reset
  • Pre-compiled metallib discovery

Build

cmake -B build -S . -DUSE_METAL=ON
cmake --build build -j

Requires macOS with Apple Silicon. Metal Toolchain (xcodebuild -downloadComponent MetalToolchain) produces default.metallib at build time; without it, kernels compile from .metal source at runtime.

Files changed

New (5 files, ~3,300 lines):

  • src/treelearner/metal_tree_learner.h — class declaration with opaque void* for ObjC types
  • src/treelearner/metal_tree_learner.mm — Objective-C++ implementation
  • src/treelearner/metal/histogram{16,64,256}.metal — Metal compute kernels
  • tests/python_package_test/test_metal.py — 21-test suite

Modified (15 files, ~770 lines):

  • CMakeLists.txtUSE_METAL option, metallib compilation, framework linking
  • include/LightGBM/config.h, src/io/config.cppdevice_type="metal"
  • src/treelearner/tree_learner.cpp — factory method
  • build-python.sh--metal flag, metallib packaging
  • docs/Installation-Guide.rst, docs/Parameters.rst — documentation
  • Template instantiation files (4) + header includes (2)

Known limitations

  • FP32 only (Metal has no double precision); gpu_use_dp is warned and disabled
  • macOS only, Apple Silicon only
  • Per-dispatch overhead (~15ms) makes GPU slower than CPU for datasets under ~300K rows — same as OpenCL/CUDA

Test plan

  • Correctness vs CPU across all objectives and kernel variants
  • Multi-workgroup reduction at scale (up to 10M rows)
  • Benchmark on standard ML datasets (Higgs, Epsilon-like, Bosch-like)
  • Bagging, constant hessian, linear tree, refit
  • USE_METAL=OFF doesn't affect other platforms
  • Zero compiler warnings (Metal shader compiler + Clang)
  • Pre-compiled metallib loading + runtime fallback

discordwell and others added 10 commits March 31, 2026 05:32
Add a native Metal compute shader backend for GPU-accelerated histogram
construction on Apple Silicon Macs. This addresses the gap left by the
deprecated OpenCL backend, which crashes on Apple Silicon (lightgbm-org#6189).

The Metal backend follows the same architecture as the existing OpenCL
GPU backend (histogram-only acceleration; split finding stays on CPU),
but is significantly simpler due to Apple Silicon's unified memory —
no pinned buffers, no async PCIe transfers needed.

New files:
- metal_tree_learner.h/.mm: MetalTreeLearner extending SerialTreeLearner
- metal/histogram{16,64,256}.metal: Metal compute kernels ported from OCL

Modified:
- CMakeLists.txt: USE_METAL option, Metal framework linking, metallib build
- config.h/config.cpp: device_type="metal" support
- tree_learner.cpp: factory method for Metal learner
- parallel/linear learner files: MetalTreeLearner template instantiations

Build: cmake -DUSE_METAL=ON .. (macOS only)
Usage: lgb.train({..., 'device': 'metal'})

Tested on Apple M4 Max with datasets up to 50K rows, all three kernel
variants (max_bin 16/64/256). Max prediction diff vs CPU: <5e-6.

Known limitation: currently uses single workgroup per feature (POWER=0)
due to Metal's lack of cross-threadgroup memory synchronization within
a single dispatch. Multi-workgroup support will follow via a two-pass
reduction kernel.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add reduce_histogram{16,64,256} kernels for future multi-workgroup support
- Strip within-kernel cross-threadgroup sync from multi-WG path (unreliable on Metal)
- Keep single-workgroup (POWER=0) as default for guaranteed correctness
- Fix buffer index mismatch: hessians=6, const_hessian=7, output=8, sync=9, hist=10
- Add comprehensive test suite (test_metal.py): 14 tests covering binary, regression,
  multiclass, all 3 kernel variants, scalability to 10K rows, bagging, constant hessian
- All 14 tests pass on Apple M4 Max

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Change the multi-workgroup sub-histogram write format from separated
(grad block + hess block per feature) to interleaved (grad/hess pairs
per bin), matching the single-workgroup output format. This enables
simple element-wise CPU-side reduction without layout conversion.

Multi-workgroup (POWER > 0) remains disabled pending kernel debugging —
the histogram accumulation produces incorrect values with multiple
threadgroups. The infrastructure is in place for when this is fixed:
- Interleaved sub-histogram writes in all 3 kernel variants
- CPU-side reduction with blit buffer clear between iterations
- pending_exp_workgroups_ tracking for conditional reduction path

All 14 Metal tests pass with single-workgroup (POWER=0).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove unconditional debug logging from multi-workgroup path
- Fix feature_masks address space: use `device const` consistently
  across all 3 kernel variants (avoids potential stale constant cache)
- Remove unused reduce_histogram kernels (dead code, POWER always 0)
- Fix log message tone ("This is the Metal trainer!!" -> professional)
- Make kMaxLogWorkgroupsPerFeature static constexpr
- Remove unused imports (time, sys) from test_metal.py
- Remove unused variable (out_elems) and redundant bridge

Zero compiler warnings. 14/14 Metal tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Root cause of multi-workgroup failure: the kernel writes features in
reverse order within each feature4 tuple (j=0 in threadgroup = .w component
= last feature in the group), matching within_kernel_reduction's reversal.
The CPU-side reduction was not reversing, causing feature histograms to
be assigned to wrong feature groups.

Fix: reverse feature index in the CPU reduction loop for histogram256
and histogram64. Histogram16 has a different layout and also needs
reversal.

Also in this commit:
- Remove all debug comparison code (METAL_DEBUG_COMPARE_POWER)
- Add documentation to Installation-Guide.rst (Build Metal Version)
- Add build-python.sh support for metallib packaging
- Improve CMakeLists.txt: metallib install dir for Python package builds
- Expand test suite: 21 tests including multi-WG large dataset variants,
  linear tree, refit, and training data reset
- Skip metallib test when Metal Toolchain not installed

20 passed, 1 skipped (metallib packaging test) in 2.3s.
All dataset sizes 100-50000 produce correct results with multi-workgroup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Remove unused `lsize` from within_kernel_reduction helper functions
in histogram256.metal and histogram64.metal, and unused `BANK_BITS`
constant from histogram16.metal. Zero warnings from Metal compiler.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove METAL_DEBUG define and all #if METAL_DEBUG blocks
- Remove unused #include <cstdio>
- Remove unused sync_counters buffer (kernel param, host allocation,
  destructor cleanup, buffer binding) — renumber hist_buf_base to
  buffer(9)
- Remove unused gsize [[threads_per_grid]] param from histogram16/64
- Replace stale TODO comments with design notes explaining CPU-side
  cross-workgroup reduction
- Add --metal flag to build-python.sh for parity with --gpu/--cuda/--rocm

Zero compiler warnings. 21/21 Metal tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Gather gradients/hessians directly into Metal shared buffer contents
  instead of intermediate vector + memcpy (BeforeTrain, BeforeFindBestSplit,
  ConstructMetalHistogramsAsync)
- Split WaitAndGetHistograms into WaitForGPU + ProcessHistogramResults
  to allow overlapping smaller leaf CPU reduction with larger leaf GPU
- Restructure ConstructHistograms for pipelined execution:
  1. Launch smaller GPU → 2. CPU sparse (overlaps) → 3. Wait GPU →
  4. Launch larger GPU → 5. Process smaller results (overlaps) →
  6. CPU sparse larger (overlaps) → 7. Wait larger GPU

Benchmark (Apple M4 Max, 50 features, num_leaves=63, 10 rounds):
  n=10K:   Metal 0.22x CPU — dispatch overhead dominates
  n=100K:  Metal 0.54x CPU — transitional
  n=500K:  Metal 1.16x CPU — GPU wins
  n=1M:    Metal 1.69x CPU — GPU significantly faster

21/21 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…uard

- Use ordered_gradients_.reserve() instead of resize() to avoid
  unnecessary zero-initialization of ~262K elements (match OpenCL)
- Add METAL_DEBUG_COMPARE mode (commented out) for CPU-vs-GPU histogram
  cross-validation, matching OpenCL's GPU_DEBUG_COMPARE pattern
- Add num_dense_feature_groups_ guard in MetalHistogram() for safety

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Member

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your interest in LightGBM and for the interesting proposal!!!

I see the following at https://developer.apple.com/opencl/

If you’re using OpenCL, which was deprecated in macOS 10.14, for GPU-based computational tasks in your Mac app, we recommend that you transition to Metal and Metal Performance Shaders for access to a wider range of capabilities.

Seems like pretty good evidence to me that we should consider this!


I left an initial set of small suggestions. Some other larger questions:

  • can / should we compile this support directly into LightGBM's macOS wheels? Or maybe just the arm64 ones? What would the impact on binary size be?
  • can we test this in CI? Does GitHub Actions or some other service have runners with Metal cards?

Since I see you've used Claude to generate this, a few questions:

  • did you yourself manually review this code before opening this? Or are we reviewing something you've directly pushed from an LLM without review?
  • Can you share the benchmarking code you used? Why did you benchmark only against single-threaded CPU? What do those benchmarks look like if you use multithreading? Knowing that training is faster on CPU w/ multithreading + OpenMP, and by how much, would help us to understand how much effort it is worthwhile to invest in this.
  • did you actually run all the tests and benchmarks in the PR description, or did Claude just write that?

And some expectation-setting:

  • when we ask questions, you should answer them in your own words (do not have LLMs respond)
  • you are expected to understand and be able to explain every change in the PR if asked
  • this project is maintained by busy volunteers and you've asked us to review a huge amount of code here (and then to maintain it)... be patient, respond to questions when we ask, and do as much local testing as you can to minimize the number of review cycles

~~~~~~~~~~~~~~~~~~~

The Metal version of LightGBM (``device_type=metal``) uses Apple Metal for histogram construction on macOS with Apple Silicon.
It supports serial training, linear trees, refit, and distributed training with the tree learners that LightGBM already exposes outside the C API.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
It supports serial training, linear trees, refit, and distributed training with the tree learners that LightGBM already exposes outside the C API.
It supports serial training, linear trees, refit, and distributed training with the tree learners that LightGBM already exposes via its C API.

I think here you meant to talk about what is exposed via the C API, not "outside of" it, right?


After compilation the executable and ``.dylib`` files will be in ``LightGBM/``.

When the Metal toolchain is available at build time, LightGBM also produces ``default.metallib`` and installs or packages it beside ``lib_lightgbm``.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify this (in a comment here, not in the docs yet). There will be a file installed at like /usr/lib/default.metallib or something similar?

  • what does that file do?
  • links to documentation where I could learn more?
  • why does it have such a generic name, shouldn't it be lib_lightgbm.metallib or something like that?

@@ -0,0 +1,328 @@
# coding: utf-8
"""Tests for Metal GPU backend on Apple Silicon."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, don't introduce any new test files here. Put these tests in the existing Python files (e.g. tests/python_package_test/test_sklearn.py for anything using the scikit-learn estimators).

data = lgb.Dataset(X, label=y)
params = {"device": "metal", "verbose": -1, "num_leaves": 2, "num_iterations": 1}
lgb.train(params, data)
except lgb.basic.LightGBMError as e:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very fragile way to test for this support. It means that changes that catastrophically break the Metal support might not result in failing tests, only skipped tests!

Please use a more direct mechanism, like pytest markers or the presence/absence of an environment variable.

_skip_if_not_metal()


class TestMetalBasic:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use pytest-style test functions here, not unittest-style classes. Please match that existing style.


#include "serial_tree_learner.h"

#ifdef USE_METAL
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#ifdef USE_METAL
#ifdef LGBM_USE_METAL

I know we don't do a good job of this in other parts of the library but I'd like to avoid making it worse... let's please prefix this define with LGBM_ everywhere. Generic names like USE_METAL are more likely to conflict when LightGBM is included with other projects. Especially given that the CMake changes are proposing using the global add_definitions() (that's fine for this PR, it can be refactored along with all the other uses of add_definitions() in #6774).

@@ -0,0 +1,1467 @@
/*!
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add .mm files to the OpenMP linting here:

get_omp_pragmas_without_num_threads() {
grep \
-n \
-R \
--include='*.c' \
--include='*.cc' \
--include='*.cpp' \
--include='*.h' \


// Compare GPU histogram with CPU histogram, useful for debugging GPU code problems.
// Uncomment the #define below to enable.
// #define METAL_DEBUG_COMPARE
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this was probably just copied from here:

// Compare GPU histogram with CPU histogram, useful for debugging GPU code problem
// #define GPU_DEBUG_COMPARE
#ifdef GPU_DEBUG_COMPARE

Did you intentionally include this and actually use it? Or is it just something Claude copied when you asked it to translate the OpenCL-based implementation?

This is a huge PR and I'm looking for opportunities to simplify it... unless you can make a strong case for this debugging code existing, please remove it.

Log::Fatal("Currently cuda version only supports training on a single machine.");
}
} else if (device_type == std::string("metal")) {
if (learner_type == std::string("serial")) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you actually test + intend to include the distributed training support here? Or is this just an artifact of you asking Claude to translate the OpenCL-based implementation to Metal?

Unless you can make a strong case for including this and explain how you used it, please remove all the distributed training code and just make this only support serial tree learning for now, as the CUDA build does a few lines above this. That'll simplify the PR.

If someone really wants to try something like distributed training across the Metal GPUs on multiple Mac Minis or something, they can open a feature request and we can try to implement this then.

Comment on lines +327 to +328
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])

Remove stuff like this from test files here. We don't support invoking individual test files like python test_metal.py ... this project exclusively uses pytest to invoke tests.

- Rename USE_METAL -> LGBM_USE_METAL across CMake, C++, shell, and docs
- Simplify CMakeLists.txt: require Metal compiler at configure time (like
  nvcc for CUDA), remove runtime compilation fallback
- Remove METAL_DEBUG_COMPARE debug scaffolding from metal_tree_learner.mm
- Restrict Metal to serial (+ linear) tree learning only, matching CUDA
  pattern; remove parallel template instantiations
- Rename metallib output from default.metallib to lib_lightgbm.metallib
- Move metallib asset copy from build-python.sh to CMake install rule
- Add .mm to OMP pragma linting in check-omp-pragmas.sh
- Convert test_metal.py: flat pytest functions, skipif with env var
  LIGHTGBM_TEST_METAL=1, remove if __name__ block
- Fix doc text: "outside the C API" -> "via its C API"

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants