Skip to content

Commit 9a3158f

Browse files
aofei-liu and claude authored
refactor: replace multiprocessing.Pool with ThreadPoolExecutor in process_single_position (#396)
* refactor: replace multiprocessing.Pool with ThreadPoolExecutor in process_single_position The transform functions passed to process_single_position (numpy, scipy, PyTorch, ANTsPy) all release the GIL, making threads sufficient for parallelism. ThreadPoolExecutor avoids the overhead of spawn-based multiprocessing (re-importing libraries, serializing arguments) and plays better with Nextflow's per-task resource accounting. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix: handle os.cpu_count() returning None and add threaded path test Guard against os.cpu_count() returning None in containerized environments by falling back to 1. Add test_process_single_position_threaded to exercise the ThreadPoolExecutor code path with num_processes=2. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor: deduplicate test_process_single_position and threaded variant Extract shared logic into _run_process_single_position helper so the threaded test delegates instead of duplicating the entire test body. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor: inline helper and parameterize num_processes via hypothesis Remove _run_process_single_position helper and merge the two test functions into a single test_process_single_position that samples num_processes from [1, 2]. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * refactor: rename num_processes to num_threads with deprecation shim Add num_threads parameter to process_single_position and deprecate num_processes. When num_processes is passed, a DeprecationWarning is emitted and the value is forwarded to num_threads if num_threads is smaller. Update tests to use num_threads. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9beb2a4 commit 9a3158f

2 files changed

Lines changed: 30 additions & 23 deletions

File tree

src/iohub/ngff/utils.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import inspect
44
import itertools
5-
import multiprocessing as mp
5+
import os
6+
import warnings
67
from collections import defaultdict
78
from collections.abc import Callable, Sequence
9+
from concurrent.futures import ThreadPoolExecutor
810
from functools import partial
911
from pathlib import Path
1012
from typing import Any, Literal
@@ -283,7 +285,8 @@ def process_single_position(
283285
output_channel_indices: list[slice] | list[list[int]] | None = None,
284286
input_time_indices: list[int] | None = None,
285287
output_time_indices: list[int] | None = None,
286-
num_processes: int = 1,
288+
num_processes: int | None = None,
289+
num_threads: int = 1,
287290
**kwargs,
288291
) -> None:
289292
"""
@@ -324,14 +327,26 @@ def process_single_position(
324327
Must match input_channel_indices if not empty.
325328
Defaults to None.
326329
num_processes : int, optional
327-
Number of simultaneous processes per position. Defaults to 1.
330+
Deprecated. Use ``num_threads`` instead. When set, its value is
331+
forwarded to ``num_threads``. If both are set to non-default values
332+
and differ, ``num_threads`` takes precedence. Defaults to None.
333+
num_threads : int, optional
334+
Number of simultaneous threads per position. Defaults to 1.
328335
kwargs : dict, optional
329336
Additional arguments to pass to the function.
330337
A dictionary with key "extra_metadata"
331338
can be passed to be stored at a FOV level,
332339
e.g.,
333340
kwargs={"extra_metadata": {"Temperature": 37.5, "CO2_level": 0.5}}.
334341
"""
342+
if num_processes is not None:
343+
warnings.warn(
344+
"num_processes is deprecated. Use num_threads instead.",
345+
DeprecationWarning,
346+
stacklevel=2,
347+
)
348+
if num_threads < num_processes:
349+
num_threads = num_processes
335350
click.echo(f"Function to be applied: \t{func}")
336351
click.echo(f"Input data path:\t{input_position_path}")
337352
click.echo(f"Output data path:\t{output_position_path}")
@@ -395,21 +410,19 @@ def process_single_position(
395410
output_position_path,
396411
**kwargs,
397412
)
398-
num_processes = min(num_processes, len(flat_iterable), mp.cpu_count())
399-
click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
400-
if num_processes <= 1:
401-
# Run serially — Pool(1) with spawn unnecessarily forks a subprocess
413+
cpu_count = os.cpu_count() or 1
414+
num_workers = min(num_threads, len(flat_iterable), cpu_count)
415+
click.echo(f"\nStarting thread pool with {num_workers} workers")
416+
if num_workers <= 1:
402417
for args in flat_iterable:
403418
partial_apply_transform_to_czyx_and_save(*args)
404419
else:
405-
# NOTE: use spawn to work around tensorstore#61
406-
context = mp.get_context("spawn")
407-
with context.Pool(num_processes) as p:
408-
p.starmap(
409-
partial_apply_transform_to_czyx_and_save,
420+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
421+
list(executor.map(
422+
lambda args: partial_apply_transform_to_czyx_and_save(*args),
410423
flat_iterable,
411-
)
412-
click.echo("Shut down multiprocess pool")
424+
))
425+
click.echo("Shut down thread pool")
413426

414427

415428
# -- Pure utility functions ------------------------------------------------

tests/ngff/test_ngff_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -723,9 +723,10 @@ def test_match_indices_to_batches(indices, shard_size):
723723
@given(
724724
setup=process_single_position_setup(),
725725
constant=st.integers(min_value=1, max_value=3),
726+
num_threads=st.sampled_from([1, 2]),
726727
)
727728
@settings(max_examples=3, deadline=None)
728-
def test_process_single_position(setup, constant):
729+
def test_process_single_position(setup, constant, num_threads):
729730
(
730731
position_keys,
731732
channel_names,
@@ -739,7 +740,6 @@ def test_process_single_position(setup, constant):
739740
version,
740741
) = setup
741742

742-
# Use the enhanced context manager to get both input and output store paths
743743
with _temp_ome_zarr_stores(
744744
position_keys=position_keys,
745745
channel_names=channel_names,
@@ -750,16 +750,13 @@ def test_process_single_position(setup, constant):
750750
dtype=dtype,
751751
version=version,
752752
) as (input_store_path, output_store_path):
753-
# Populate Store with random data
754753
populate_store(input_store_path, position_keys, shape, dtype)
755754

756-
# Choose a single position to process (e.g., the first one)
757755
for position_key_tuple in position_keys:
758756
input_position_path = input_store_path / Path(*position_key_tuple)
759757
output_position_path = output_store_path / Path(*position_key_tuple)
760758
kwargs = {"constant": constant, "extra_metadata": {"temp": 10}}
761759

762-
# Apply the transformation using process_single_position
763760
process_single_position(
764761
func=dummy_transform,
765762
input_position_path=input_position_path,
@@ -768,18 +765,15 @@ def test_process_single_position(setup, constant):
768765
output_channel_indices=channel_indices,
769766
input_time_indices=time_indices,
770767
output_time_indices=time_indices,
768+
num_threads=num_threads,
771769
**kwargs,
772770
)
773771

774-
# Handle None for process_single_position_setup
775772
if time_indices is None:
776773
time_indices = list(range(shape[0]))
777774
if channel_indices is None:
778775
channel_indices = [[c] for c in range(shape[1])]
779776

780-
print("time_indices", time_indices)
781-
print("channel_indices", channel_indices)
782-
# Verify the transformation
783777
iterable = itertools.product(time_indices, channel_indices)
784778
for t_idx, chan_idx in iterable:
785779
verify_transformation(

0 commit comments

Comments (0)