
Commit 1c739a2

ieivanov and claude authored
refactor: add num_workers/use_threads to process_single_position (#410)
* refactor: add num_workers/use_threads to process_single_position

  PR #396 replaced mp.Pool with ThreadPoolExecutor on the assumption that the
  transforms passed to process_single_position release the GIL, so threads
  suffice. That holds for I/O-bound callers, but not for tensor-heavy CPU torch
  workloads (deskew, register, deconvolve): under threads, all concurrent task
  allocations live in one address space, and torch's CPU caching allocator
  never returns memory to the OS, so peak RSS climbs past the slurm cgroup
  limit. Process workers are still needed for those cases.

  Introduce two new public params and deprecate the old ones:

  * num_workers (default 1) — replaces num_processes (#396 already deprecated
    this) and num_threads. Both legacy names emit a DeprecationWarning and
    forward to num_workers.
  * use_threads (default False) — pick between ThreadPoolExecutor and
    ProcessPoolExecutor.

  Behaviour:

  * num_workers <= 1 -> serial loop in the calling process (matches the
    short-circuit added in #396).
  * num_workers > 1, use_threads=True -> ThreadPoolExecutor (the #396 default).
  * num_workers > 1, use_threads=False -> ProcessPoolExecutor with the spawn
    context (the new default).

  Two reasons to use ProcessPoolExecutor (rather than the pre-#396 mp.Pool):

  1. Silent worker death — a slurm cgroup OOM-kill of one worker leaves
     mp.Pool.starmap waiting forever for a result that never comes.
     ProcessPoolExecutor surfaces this as BrokenProcessPool, so the slurm job
     fails fast with a real traceback instead of hanging until walltime.
  2. Spawn (not fork) — tensorstore's internal C++ threads aren't fork-safe
     (google/tensorstore#61), and multiprocessing defaults to fork on Linux.

  Verified end-to-end on a 57-timepoint deskew run (171 (T,C) tasks per fov,
  8 workers): both pool variants and the serial path produce bit-identical
  output, and an intentional OOM under the process pool fails within
  ~1 minute with BrokenProcessPool instead of hanging.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor: close input zarr handle and short-circuit nan/zero check

  Two small cleanups in iohub.ngff.utils that surfaced while debugging deskew
  memory pressure:

  1. `_apply_transform_to_czyx` opened the input zarr without a context
     manager, leaking the zarr group / metadata cache for the lifetime of the
     worker. Wrap it in `with open_ome_zarr(...)` so the handle is released
     after each task. No measurable memory effect at the cgroup level — a
     file-handle hygiene fix that matters most for very long task queues.
  2. `_check_nan_n_zeros` materialised a full boolean mask of the input volume
     (via `np.all(arr == 0)`) before reducing it. Replace it with
     `np.any(arr)`, which short-circuits in the numpy C reduction kernel as
     soon as it sees a truthy element and does not allocate a temporary mask.
     The all-NaN branch only runs when `np.any` returned True (i.e. the array
     contains content or NaNs); skip it entirely for integer dtypes that
     cannot represent NaN.

  Behaviour-preserving: produces the same return value as the previous
  implementation for all 3D and 4D inputs, including the per-channel "any
  channel empty" semantics for 4D arrays. Verified end-to-end on the deskew
  workload; bit-identical outputs.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
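The memory claim in item 2 can be sanity-checked with standalone numpy; the array shape below is illustrative, not from this PR:

import numpy as np

zyx = np.zeros((64, 512, 512), dtype=np.float32)
zyx[0, 0, 0] = 1.0  # a truthy element right at the start

# Old check: allocates a full (64, 512, 512) boolean temporary, then reduces.
old_empty = np.all(zyx == 0)

# New check: no temporary mask is built; per the commit message the C
# reduction can also stop at the first truthy element it encounters.
new_empty = not np.any(zyx)

assert not old_empty and not new_empty  # same answer, very different cost
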
* test: parametrise test_process_single_position over num_workers/use_threads

  Renames the hypothesis strategy from `num_threads` to `num_workers` to match
  the new public API, and adds a `use_threads` boolean strategy so the test
  exercises both the ProcessPoolExecutor (default) and ThreadPoolExecutor
  paths. The old test only covered serial + threads.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: assert num_processes/num_threads emit DeprecationWarning

  Adds a parametrized regression test that asserts both legacy kwargs trigger
  a DeprecationWarning when forwarded to num_workers. The warnings are
  otherwise invisible at runtime under Python's default filter (which
  suppresses DeprecationWarning raised from package code), so this is the only
  practical way to catch a future accidental removal of the shim.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* test: put repo root on PYTHONPATH so spawn workers can import tests/

  `test_process_single_position` parametrises over `use_threads ∈ {True,
  False}`. With `use_threads=False`, iohub spins up a `ProcessPoolExecutor`
  with the `spawn` context. Spawn children re-initialise sys.path from the
  runtime defaults plus PYTHONPATH; they do not inherit pytest's
  `--import-mode=importlib` sys.path manipulation. Unpickling the test-local
  `dummy_transform` (which lives at `tests.ngff.test_ngff_utils.
  dummy_transform`) therefore fails with `ModuleNotFoundError: No module named
  'tests'` and the worker dies, surfacing as `BrokenProcessPool` in the
  parent.

  Fix: prepend the repo root to PYTHONPATH (and to the parent's sys.path for
  symmetry) in `tests/conftest.py`. Spawn children inherit PYTHONPATH via the
  OS env, so they can now resolve `tests.ngff.test_ngff_utils` and unpickle
  the function.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat: honour SLURM_CPUS_PER_TASK when capping num_workers

  `os.cpu_count()` reports the host's total CPUs, not the cgroup CPU
  allocation. On a 128-core slurm node where the job was granted only 8 cores,
  capping `num_workers` at `os.cpu_count()` lets a caller oversubscribe the
  cgroup. Add `_available_cpus()` that prefers the `SLURM_CPUS_PER_TASK` env
  var when present and falls back to `os.cpu_count()` otherwise.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor!: drop num_processes and num_threads kwargs

  Both were deprecated in the previous commit ('refactor: add
  num_workers/use_threads to process_single_position'), with shims that
  forwarded their values to num_workers. Drop the shims now — anything still
  passing num_processes / num_threads gets a TypeError pointing at the right
  argument name, which is more useful than a silent DeprecationWarning that
  callers may never see (Python suppresses DeprecationWarning raised from
  package code under the default filter). Removes the corresponding regression
  test (test_process_single_position_legacy_kwargs_deprecated) and the unused
  'warnings' import.

  BREAKING CHANGE: callers of process_single_position must use num_workers
  (and, optionally, use_threads) instead of num_processes / num_threads.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* revert conftest.py

* Revert "revert conftest.py"

  This reverts commit 0f86c59.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
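A hypothetical call site for the final API; the transform and zarr paths below are illustrative, not from this PR:

import numpy as np

from iohub.ngff.utils import process_single_position


def add_one_czyx(czyx: np.ndarray) -> np.ndarray:
    # Stand-in for a real CZYX -> CZYX transform (deskew, register, ...).
    # It must live at module level so spawn workers can unpickle it by name.
    return czyx + 1


# Default path: ProcessPoolExecutor with the spawn context, capped at
# min(num_workers, number of (T, C) tasks, _available_cpus()).
process_single_position(
    add_one_czyx,
    input_position_path="plate.zarr/A/1/0",  # hypothetical paths
    output_position_path="out.zarr/A/1/0",
    num_workers=8,
)

# I/O-bound transforms that release the GIL can opt back into threads.
process_single_position(
    add_one_czyx,
    input_position_path="plate.zarr/A/1/0",
    output_position_path="out.zarr/A/1/0",
    num_workers=8,
    use_threads=True,
)
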
1 parent f281ac3 commit 1c739a2

3 files changed

Lines changed: 120 additions & 43 deletions


src/iohub/ngff/utils.py

Lines changed: 80 additions & 40 deletions
@@ -2,11 +2,11 @@
 
 import inspect
 import itertools
+import multiprocessing as mp
 import os
-import warnings
 from collections import defaultdict
 from collections.abc import Callable, Sequence
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from functools import partial
 from pathlib import Path
 from typing import Any, Literal
@@ -165,8 +165,8 @@ def _apply_transform_to_czyx(
         kwargs["input_time_index"] = input_time_index
 
     click.echo(f"Processing t={input_time_index}, c={input_channel_indices}")
-    input_dataset = open_ome_zarr(input_position_path, layout="fov", mode="r")
-    czyx_data = input_dataset.data.oindex[input_time_index, input_channel_indices]
+    with open_ome_zarr(input_position_path, layout="fov", mode="r") as input_dataset:
+        czyx_data = input_dataset.data.oindex[input_time_index, input_channel_indices]
     if not _check_nan_n_zeros(czyx_data):
         return func(czyx_data, **kwargs)
     else:
@@ -279,6 +279,21 @@ def _slice_to_list(indices: list[int] | slice) -> list[int]:
     return indices
 
 
+def _available_cpus() -> int:
+    """Return the CPU count the current process is allowed to use.
+
+    Slurm exports ``SLURM_CPUS_PER_TASK`` for tasks that ask for more than
+    one CPU, which reflects the cgroup CPU allocation rather than the
+    host's total CPU count. Honouring it here prevents oversubscribing
+    the cgroup when ``os.cpu_count()`` reports the whole node (e.g. 128)
+    while slurm only granted us a few cores.
+    """
+    slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
+    if slurm_cpus and slurm_cpus.isdigit():
+        return int(slurm_cpus)
+    return os.cpu_count() or 1
+
+
 def process_single_position(
     func: Callable[[NDArray, Any], NDArray],
     input_position_path: Path,
@@ -287,8 +302,8 @@ def process_single_position(
     output_channel_indices: list[slice] | list[list[int]] | None = None,
     input_time_indices: list[int] | None = None,
     output_time_indices: list[int] | None = None,
-    num_processes: int | None = None,
-    num_threads: int = 1,
+    num_workers: int = 1,
+    use_threads: bool = False,
     **kwargs,
 ) -> None:
     """
@@ -328,27 +343,21 @@ def process_single_position(
         If empty, write to all channels.
         Must match input_channel_indices if not empty.
         Defaults to None.
-    num_processes : int, optional
-        Deprecated. Use ``num_threads`` instead. When set, its value is
-        forwarded to ``num_threads``. If both are set to non-default values
-        and differ, ``num_threads`` takes precedence. Defaults to None.
-    num_threads : int, optional
-        Number of simultaneous threads per position. Defaults to 1.
+    num_workers : int, optional
+        Number of simultaneous workers (processes or threads) per position.
+        If <= 1, the work is performed serially in the calling process.
+        Defaults to 1.
+    use_threads : bool, optional
+        If True, parallelize across threads via ``ThreadPoolExecutor``;
+        otherwise spawn worker processes via ``ProcessPoolExecutor``.
+        Defaults to False.
     kwargs : dict, optional
         Additional arguments to pass to the function.
        A dictionary with key "extra_metadata"
        can be passed to be stored at a FOV level,
        e.g.,
        kwargs={"extra_metadata": {"Temperature": 37.5, "CO2_level": 0.5}}.
     """
-    if num_processes is not None:
-        warnings.warn(
-            "num_processes is deprecated. Use num_threads instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        if num_threads < num_processes:
-            num_threads = num_processes
     click.echo(f"Function to be applied: \t{func}")
     click.echo(f"Input data path:\t{input_position_path}")
     click.echo(f"Output data path:\t{output_position_path}")
@@ -412,39 +421,70 @@ def process_single_position(
         output_position_path,
         **kwargs,
     )
-    cpu_count = os.cpu_count() or 1
-    num_workers = min(num_threads, len(flat_iterable), cpu_count)
-    click.echo(f"\nStarting thread pool with {num_workers} workers")
+    num_workers = min(num_workers, len(flat_iterable), _available_cpus())
     if num_workers <= 1:
+        click.echo("\nRunning serially in the calling process (num_workers <= 1)")
         for args in flat_iterable:
             partial_apply_transform_to_czyx_and_save(*args)
+        click.echo("Done")
+    elif use_threads:
+        click.echo(f"\nStarting thread pool with {num_workers} threads")
+        with ThreadPoolExecutor(max_workers=num_workers) as p:
+            futures = [
+                p.submit(partial_apply_transform_to_czyx_and_save, *args)
+                for args in flat_iterable
+            ]
+            for fut in as_completed(futures):
+                fut.result()
+        click.echo("Shut down thread pool")
     else:
-        with ThreadPoolExecutor(max_workers=num_workers) as executor:
-            list(
-                executor.map(
-                    lambda args: partial_apply_transform_to_czyx_and_save(*args),
-                    flat_iterable,
-                )
-            )
-        click.echo("Shut down thread pool")
+        click.echo(f"\nStarting multiprocess pool with {num_workers} processes")
+        # NOTE: spawn (not fork) — tensorstore runs internal C++ threads
+        # that are not fork-safe, so a forked worker can deadlock or
+        # segfault before our code runs. See google/tensorstore#61.
+        # NOTE: ProcessPoolExecutor (not mp.Pool) so silent worker death
+        # (e.g. cgroup OOM-kill) surfaces as BrokenProcessPool instead
+        # of hanging indefinitely on pool.starmap.
+        context = mp.get_context("spawn")
+        with ProcessPoolExecutor(
+            max_workers=num_workers, mp_context=context
+        ) as p:
+            futures = [
+                p.submit(partial_apply_transform_to_czyx_and_save, *args)
+                for args in flat_iterable
+            ]
+            for fut in as_completed(futures):
+                fut.result()
+        click.echo("Shut down multiprocess pool")
 
 
 # -- Pure utility functions ------------------------------------------------
 
 
 def _check_nan_n_zeros(input_array) -> bool:
     """Checks if any of the channels are all zeros or nans."""
-    if len(input_array.shape) == 3:
-        if np.all(input_array == 0) or np.all(np.isnan(input_array)):
-            return True
-    elif len(input_array.shape) == 4:
-        num_channels = input_array.shape[0]
-        for c in range(num_channels):
-            zyx_array = input_array[c, :, :, :]
-            if np.all(zyx_array == 0) or np.all(np.isnan(zyx_array)):
-                return True
+    if input_array.ndim == 3:
+        return _zyx_is_all_zero_or_nan(input_array)
+    elif input_array.ndim == 4:
+        return any(_zyx_is_all_zero_or_nan(input_array[c]) for c in range(input_array.shape[0]))
     else:
         raise ValueError("Input array must be 3D or 4D")
+
+
+def _zyx_is_all_zero_or_nan(zyx_array) -> bool:
+    """All-zero or all-NaN test that short-circuits on the first counter-example.
+
+    `np.any(arr)` returns False iff every element is 0/False, and short-circuits
+    in C as soon as it finds a truthy value. The previous `np.all(arr == 0)`
+    materialised a full boolean mask of the input volume before reducing it.
+    """
+    if not np.any(zyx_array):
+        return True  # all zeros
+    # NaN is truthy in numpy bool context, so the explicit NaN check is only
+    # needed when np.any returned True (otherwise the array is all zeros and
+    # would not reach here).
+    if zyx_array.dtype.kind == "f" and np.isnan(zyx_array).all():
+        return True  # all NaN
     return False
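The fail-fast behaviour described in the NOTE comments above can be reproduced in isolation. A minimal POSIX-only sketch, separate from iohub, in which a worker killing itself stands in for a cgroup OOM-kill:

import multiprocessing as mp
import os
import signal
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool


def _die_abruptly(_: int) -> None:
    # Stand-in for a cgroup OOM-kill: the worker vanishes without replying.
    os.kill(os.getpid(), signal.SIGKILL)


if __name__ == "__main__":
    with ProcessPoolExecutor(
        max_workers=2, mp_context=mp.get_context("spawn")
    ) as pool:
        future = pool.submit(_die_abruptly, 0)
        try:
            future.result()  # raises promptly instead of blocking forever
        except BrokenProcessPool as exc:
            print(f"pool broke fast: {exc}")
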
tests/conftest.py

Lines changed: 14 additions & 0 deletions
@@ -1,6 +1,7 @@
 import csv
 import os
 import shutil
+import sys
 from pathlib import Path
 
 import fsspec
@@ -13,6 +14,19 @@
 settings.load_profile("default")
 
 
+# Make the repo root importable from `multiprocessing` spawn children so that
+# tests using ProcessPoolExecutor (e.g. test_process_single_position with
+# use_threads=False) can unpickle helpers like `tests.ngff.test_ngff_utils.
+# dummy_transform`. pytest's `--import-mode=importlib` only manipulates the
+# parent process's sys.path, not the env that spawn children inherit.
+_REPO_ROOT = Path(__file__).resolve().parent.parent
+os.environ["PYTHONPATH"] = os.pathsep.join(
+    [str(_REPO_ROOT)] + ([os.environ["PYTHONPATH"]] if os.environ.get("PYTHONPATH") else [])
+)
+if str(_REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(_REPO_ROOT))
+
+
 @pytest.fixture
 def rng():
     return np.random.default_rng(42)
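The unpickling constraint behind this change is visible without pytest: a process pool ships each task as a pickle that records only the function's module and qualified name, which the spawn child must be able to import. A standalone sketch:

import pickle


def module_level_transform(x: int) -> int:
    return x + 1


# Pickles as a reference ("__main__", "module_level_transform"); a spawn
# child re-imports the module by name to resolve it, which is exactly what
# fails when 'tests' is not importable in the child.
payload = pickle.dumps(module_level_transform)
assert pickle.loads(payload)(1) == 2

# Objects without an importable qualified name are rejected up front.
try:
    pickle.dumps(lambda x: x + 1)
except (pickle.PicklingError, AttributeError) as exc:
    print(f"cannot ship to a worker: {exc}")
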

tests/ngff/test_ngff_utils.py

Lines changed: 26 additions & 3 deletions
@@ -14,6 +14,7 @@
 from iohub.core.compat import V04_MAX_CHUNK_SIZE_BYTES
 from iohub.ngff import open_ome_zarr
 from iohub.ngff.utils import (
+    _available_cpus,
     _indices_to_shard_aligned_batches,
     _match_indices_to_batches,
     _V05_DEFAULT_ZYX_CHUNKS,
@@ -737,10 +738,11 @@ def test_match_indices_to_batches(indices, shard_size):
 @given(
     setup=process_single_position_setup(),
     constant=st.integers(min_value=1, max_value=3),
-    num_threads=st.sampled_from([1, 2]),
+    num_workers=st.sampled_from([1, 2]),
+    use_threads=st.booleans(),
 )
 @settings(max_examples=3, deadline=None)
-def test_process_single_position(setup, constant, num_threads):
+def test_process_single_position(setup, constant, num_workers, use_threads):
     (
         position_keys,
         channel_names,
@@ -779,7 +781,8 @@ def test_process_single_position(setup, constant, num_threads):
         output_channel_indices=channel_indices,
         input_time_indices=time_indices,
         output_time_indices=time_indices,
-        num_threads=num_threads,
+        num_workers=num_workers,
+        use_threads=use_threads,
         **kwargs,
     )
 
@@ -802,6 +805,26 @@
     )
 
 
+@pytest.mark.parametrize(
+    ("env", "expected_min", "expected_max"),
+    [
+        ("4", 4, 4),  # honour SLURM_CPUS_PER_TASK exactly
+        (None, 1, None),  # fall back to os.cpu_count() when unset
+        ("", 1, None),  # fall back when empty
+        ("abc", 1, None),  # fall back when non-numeric
+    ],
+)
+def test_available_cpus_honours_slurm_env(monkeypatch, env, expected_min, expected_max):
+    if env is None:
+        monkeypatch.delenv("SLURM_CPUS_PER_TASK", raising=False)
+    else:
+        monkeypatch.setenv("SLURM_CPUS_PER_TASK", env)
+    n = _available_cpus()
+    assert n >= expected_min
+    if expected_max is not None:
+        assert n == expected_max
+
+
 # -- Explicit tests for version-specific chunk/shard defaults -----------------
 #
 # The hypothesis-based test_create_empty_plate exercises many parameter
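These cases feed the same capping rule the diff applies as `min(num_workers, len(flat_iterable), _available_cpus())`. A standalone sketch with illustrative numbers:

import os


def effective_workers(requested: int, n_tasks: int) -> int:
    # Prefer the cgroup allocation slurm exported; fall back to the host count.
    slurm = os.environ.get("SLURM_CPUS_PER_TASK")
    allowed = int(slurm) if slurm and slurm.isdigit() else (os.cpu_count() or 1)
    return min(requested, n_tasks, allowed)


os.environ["SLURM_CPUS_PER_TASK"] = "8"  # pretend slurm granted 8 of 128 cores
print(effective_workers(requested=32, n_tasks=171))  # prints 8
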
