Skip to content

Commit b2bfac8

Browse files
authored
Bound audio-analysis CPU usage and silence NNPACK spam on ARM (#4257)
1 parent 9058d49 commit b2bfac8

9 files changed

Lines changed: 72 additions & 23 deletions

File tree

music_assistant/controllers/streams/audio_analysis.py

Lines changed: 30 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from music_assistant.helpers.api import api_command
2929
from music_assistant.helpers.datetime import local_clock_time_to_utc
3030
from music_assistant.helpers.json import json_dumps, json_loads
31+
from music_assistant.helpers.util import is_arm
3132
from music_assistant.models.audio_analysis import AudioAnalysisData
3233
from music_assistant.models.audio_analysis_provider import AudioAnalysisProvider
3334
from music_assistant.models.music_provider import MusicProvider
@@ -136,7 +137,10 @@ def __init__(self, streams: StreamsController) -> None:
136137
self.logger = self.mass.logger.getChild("audio_analysis")
137138
self._active_sessions: dict[str, set[str]] = {}
138139
self._workers: dict[str, asyncio.Task[None]] = {}
139-
self._thread_caps_configured = False
140+
self._inference_runtime_configured = False
141+
# Kept alive to persist the process-wide native BLAS thread cap (set in
142+
# ensure_inference_runtime_configured); never used as a context manager.
143+
self._blas_limiter: object | None = None
140144

141145
def setup(self) -> None:
142146
"""Register the nightly background scan task."""
@@ -162,32 +166,49 @@ async def close(self) -> None:
162166
if workers:
163167
await asyncio.gather(*workers, return_exceptions=True)
164168

165-
def ensure_thread_caps_configured(self) -> None:
169+
def ensure_inference_runtime_configured(self) -> None:
166170
"""
167-
Cap PyTorch threading for analysis inference (process-wide, applied once).
171+
Configure the on-device inference runtime for analysis (process-wide, applied once).
168172
169173
Torch-backed analysis providers call this at the start of their handle_async_init,
170174
before loading their models.
171175
"""
172-
# Lazy torch import: only torch-backed providers call this, so a host running no
173-
# such provider never imports torch. Running before the first model load also lets
174-
# set_num_interop_threads take effect (it can only be set before the first torch op).
175-
if self._thread_caps_configured:
176+
if self._inference_runtime_configured:
176177
return
178+
# Lazy imports: only torch-backed providers call this, so a host running no such
179+
# provider never imports torch/threadpoolctl. Running before the first model load
180+
# also lets set_num_interop_threads take effect (only settable before the first op).
181+
import threadpoolctl # noqa: PLC0415
177182
import torch # noqa: PLC0415
178183

179184
budget = self._aa_thread_budget()
180185
torch.set_num_threads(budget)
181186
with contextlib.suppress(RuntimeError):
182187
# set_num_interop_threads can only be called before the first torch op
183188
torch.set_num_interop_threads(1)
189+
# torch.set_num_threads only governs torch's own ops. The per-block librosa/numpy
190+
# feature extraction runs through the native BLAS pool (OpenBLAS), which otherwise
191+
# spawns a thread per core per worker and, across concurrent sessions, saturates
192+
# every core and starves playback. Cap it to the same budget; the limiter is kept
193+
# alive on the controller so the cap persists for the process.
194+
self._blas_limiter = threadpoolctl.threadpool_limits(limits=budget, user_api="blas")
195+
arm = is_arm()
196+
if arm:
197+
# NNPACK frequently fails to initialize on ARM SBCs (e.g. Raspberry Pi); torch
198+
# then re-logs "Could not initialize NNPACK" to stderr on every conv op. The fp32
199+
# conv fallback is used on those hosts regardless, so disabling it only removes
200+
# the log spam.
201+
with contextlib.suppress(Exception):
202+
torch.backends.nnpack.set_flags(False) # type: ignore[no-untyped-call]
184203
self.logger.info(
185-
"AudioAnalysis thread caps: torch intra=%d, torch interop=%d",
204+
"AudioAnalysis runtime: torch intra=%d interop=%d, blas<=%d, nnpack=%s",
186205
torch.get_num_threads(),
187206
torch.get_num_interop_threads(),
207+
budget,
208+
"off" if arm else "on",
188209
)
189210
# Only mark done once configuration actually succeeded, so a failure retries.
190-
self._thread_caps_configured = True
211+
self._inference_runtime_configured = True
191212

192213
@property
193214
def providers(self) -> list[AudioAnalysisProvider]:

music_assistant/helpers/util.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ def get_total_system_memory() -> float:
130130
return 0.0
131131

132132

133+
def is_arm() -> bool:
    """
    Return whether the host CPU is ARM-based (32- or 64-bit).

    The check is a case-insensitive prefix match on ``platform.machine()`` rather
    than a fixed allow-list, so every ARM variant is recognized: 64-bit
    (``aarch64``, ``aarch64_be``, ``arm64``/``ARM64``) as well as 32-bit
    (``armv8l``, ``armv8b``, ``armv7l``, ``armv6l``, ``armv5tel``, ...).
    An empty or unknown machine string yields ``False``.
    """
    # str.startswith accepts a tuple of prefixes: one call covers both families.
    return platform.machine().lower().startswith(("arm", "aarch64"))
136+
137+
133138
def verify_system_meets_requirements(
134139
*,
135140
feature_name: str,

music_assistant/providers/smart_fades/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"name": "Smart Fades",
55
"description": "Smart fades analyzes beat and downbeat detection, energy and musical key for smart crossfades.",
66
"codeowners": ["@music-assistant"],
7-
"requirements": ["beat-this==1.1.0", "nnAudio==0.3.3"],
7+
"requirements": ["beat-this==1.1.0", "nnAudio==0.3.3", "threadpoolctl==3.6.0"],
88
"credits": ["[Beat This!](https://github.com/CPJKU/beat_this)", "[skey](https://github.com/deezer/skey)"],
99
"documentation": "https://music-assistant.io/audio-analysis/smart-fades/",
1010
"multi_instance": false,

music_assistant/providers/smart_fades/provider.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import asyncio
6-
import platform
76
import time
87
from dataclasses import dataclass, field
98
from typing import TYPE_CHECKING, Any
@@ -16,6 +15,7 @@
1615
from torchaudio.transforms import SpectralCentroid
1716

1817
from music_assistant.constants import VERBOSE_LOG_LEVEL
18+
from music_assistant.helpers.util import is_arm
1919
from music_assistant.models.audio_analysis import AudioAnalysisData
2020
from music_assistant.models.audio_analysis_provider import AudioAnalysisProvider
2121

@@ -73,8 +73,8 @@ def __init__(
7373

7474
async def handle_async_init(self) -> None:
7575
"""Handle async initialization of the provider."""
76-
# Configure torch thread caps before loading any model (see the controller method).
77-
self.mass.streams.audio_analysis.ensure_thread_caps_configured()
76+
# Configure the inference runtime before loading any model (see the controller method).
77+
self.mass.streams.audio_analysis.ensure_inference_runtime_configured()
7878
(
7979
self._beat_this_model,
8080
self._beat_this_post_processor,
@@ -88,8 +88,7 @@ def _initialize_models(self) -> tuple[Any, ...]:
8888
"""Initialize ML models (runs in a thread to avoid blocking the event loop)."""
8989
beat_this_model = Spect2Frames(checkpoint_path="small0", device=self._device)
9090
# torch aarch64 wheels advertise fbgemm in supported_engines but its kernels are x86-only.
91-
is_arm = platform.machine().lower() in ("arm64", "aarch64", "armv8l", "armv7l")
92-
preference = ("qnnpack", "fbgemm") if is_arm else ("fbgemm", "qnnpack")
91+
preference = ("qnnpack", "fbgemm") if is_arm() else ("fbgemm", "qnnpack")
9392
supported_engines = torch.backends.quantized.supported_engines
9493
quantized_engine = next((e for e in preference if e in supported_engines), None)
9594
if quantized_engine is not None and torch.backends.quantized.engine != quantized_engine:

music_assistant/providers/sonic_analysis/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ async def handle_async_init(self) -> None:
338338
min_cpu_cores=MIN_CPU_CORES,
339339
require_ml_inference=True,
340340
)
341-
# Configure torch thread caps before loading the model (see the controller method).
342-
self.mass.streams.audio_analysis.ensure_thread_caps_configured()
341+
# Configure the inference runtime before loading the model (see the controller method).
342+
self.mass.streams.audio_analysis.ensure_inference_runtime_configured()
343343
(
344344
self._clap_model,
345345
self._clap_text_embeddings,

music_assistant/providers/sonic_analysis/manifest.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"transformers==5.6.2",
99
"huggingface-hub==1.12.0",
1010
"PyYAML==6.0.3",
11-
"torchlibrosa==0.1.0"
11+
"torchlibrosa==0.1.0",
12+
"threadpoolctl==3.6.0"
1213
],
1314
"credits": [
1415
"[Microsoft CLAP](https://github.com/microsoft/CLAP)",

requirements_all.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ soundcloudpy==0.1.4
8686
sounddevice==0.5.5
8787
srptools>=1.0.0
8888
sxm==0.2.8
89+
threadpoolctl==3.6.0
8990
torch==2.11.0+cpu; sys_platform == 'linux' and platform_machine == 'x86_64'
9091
torch==2.11.0; sys_platform != 'linux' or platform_machine != 'x86_64'
9192
torchaudio==2.11.0+cpu; sys_platform == 'linux' and platform_machine == 'x86_64'

tests/controllers/streams/test_audio_analysis.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,19 @@ async def test_distribute_chunk_calls_all_providers() -> None:
4747
p2.process_pcm_chunk.assert_awaited_once_with(session_key, b"\x00" * 1024)
4848

4949

50-
def test_ensure_thread_caps_configured_is_idempotent() -> None:
51-
"""Torch thread caps are applied once per controller, however many providers init."""
50+
def test_ensure_inference_runtime_configured_is_idempotent() -> None:
51+
"""The inference runtime (torch + native BLAS caps) is configured once per controller."""
5252
controller = _make_controller()
53-
with patch("torch.set_num_threads") as set_threads, patch("torch.set_num_interop_threads"):
54-
controller.ensure_thread_caps_configured()
55-
controller.ensure_thread_caps_configured()
53+
with (
54+
patch("torch.set_num_threads") as set_threads,
55+
patch("torch.set_num_interop_threads"),
56+
patch("threadpoolctl.threadpool_limits") as blas_limits,
57+
patch("torch.backends.nnpack.set_flags"),
58+
):
59+
controller.ensure_inference_runtime_configured()
60+
controller.ensure_inference_runtime_configured()
5661
set_threads.assert_called_once()
62+
blas_limits.assert_called_once()
5763

5864

5965
@pytest.mark.asyncio

tests/core/test_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,3 +527,19 @@ def test_system_meets_requirements(cpu_cores: int, total_gb: float, expected: bo
527527
patch("music_assistant.helpers.util.get_total_system_memory", return_value=total_gb),
528528
):
529529
assert util.system_meets_requirements(min_memory_gb=6.0, min_cpu_cores=4) is expected
530+
531+
532+
@pytest.mark.parametrize(
533+
("machine", "expected"),
534+
[
535+
("aarch64", True),
536+
("arm64", True),
537+
("armv7l", True),
538+
("x86_64", False),
539+
("AMD64", False),
540+
],
541+
)
542+
def test_is_arm(machine: str, expected: bool) -> None:
543+
"""is_arm recognizes 32/64-bit ARM and rejects x86."""
544+
with patch("music_assistant.helpers.util.platform.machine", return_value=machine):
545+
assert util.is_arm() is expected

0 commit comments

Comments
 (0)