Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/iohub/ngff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import inspect
import itertools
import multiprocessing as mp
import os
import warnings
from collections import defaultdict
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -283,7 +285,8 @@ def process_single_position(
output_channel_indices: list[slice] | list[list[int]] | None = None,
input_time_indices: list[int] | None = None,
output_time_indices: list[int] | None = None,
num_processes: int = 1,
num_processes: int | None = None,
num_threads: int = 1,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -324,14 +327,26 @@ def process_single_position(
Must match input_channel_indices if not empty.
Defaults to None.
num_processes : int, optional
Number of simultaneous processes per position. Defaults to 1.
        Deprecated. Use ``num_threads`` instead. When set, a
        ``DeprecationWarning`` is emitted and ``num_threads`` is raised to at
        least this value (the larger of the two is used). Defaults to None.
num_threads : int, optional
Number of simultaneous threads per position. Defaults to 1.
kwargs : dict, optional
Additional arguments to pass to the function.
A dictionary with key "extra_metadata"
can be passed to be stored at a FOV level,
e.g.,
kwargs={"extra_metadata": {"Temperature": 37.5, "CO2_level": 0.5}}.
"""
if num_processes is not None:
warnings.warn(
"num_processes is deprecated. Use num_threads instead.",
DeprecationWarning,
stacklevel=2,
)
if num_threads < num_processes:
num_threads = num_processes
click.echo(f"Function to be applied: \t{func}")
click.echo(f"Input data path:\t{input_position_path}")
click.echo(f"Output data path:\t{output_position_path}")
Expand Down Expand Up @@ -395,21 +410,19 @@ def process_single_position(
output_position_path,
**kwargs,
)
num_processes = min(num_processes, len(flat_iterable), mp.cpu_count())
click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
if num_processes <= 1:
# Run serially — Pool(1) with spawn unnecessarily forks a subprocess
cpu_count = os.cpu_count() or 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail on BRUNO: `os.cpu_count()` returns the node's total CPU count (~128, depending on the node), not the CPUs actually allocated to the job.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_workers = min(num_threads, len(flat_iterable), cpu_count)
click.echo(f"\nStarting thread pool with {num_workers} workers")
if num_workers <= 1:
for args in flat_iterable:
partial_apply_transform_to_czyx_and_save(*args)
else:
# NOTE: use spawn to work around tensorstore#61
context = mp.get_context("spawn")
with context.Pool(num_processes) as p:
p.starmap(
partial_apply_transform_to_czyx_and_save,
with ThreadPoolExecutor(max_workers=num_workers) as executor:
list(executor.map(
Comment thread
aofei-liu marked this conversation as resolved.
lambda args: partial_apply_transform_to_czyx_and_save(*args),
flat_iterable,
)
click.echo("Shut down multiprocess pool")
))
Comment thread
aofei-liu marked this conversation as resolved.
click.echo("Shut down thread pool")


# -- Pure utility functions ------------------------------------------------
Expand Down
12 changes: 3 additions & 9 deletions tests/ngff/test_ngff_utils.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. The only suggestion: could we merge this new test into the previous one — e.g., add a `thread_count` variable to the `@given` decorator instead of a separate test?

Original file line number Diff line number Diff line change
Expand Up @@ -723,9 +723,10 @@ def test_match_indices_to_batches(indices, shard_size):
@given(
setup=process_single_position_setup(),
constant=st.integers(min_value=1, max_value=3),
num_threads=st.sampled_from([1, 2]),
)
@settings(max_examples=3, deadline=None)
def test_process_single_position(setup, constant):
def test_process_single_position(setup, constant, num_threads):
(
position_keys,
channel_names,
Expand All @@ -739,7 +740,6 @@ def test_process_single_position(setup, constant):
version,
) = setup

# Use the enhanced context manager to get both input and output store paths
with _temp_ome_zarr_stores(
position_keys=position_keys,
channel_names=channel_names,
Expand All @@ -750,16 +750,13 @@ def test_process_single_position(setup, constant):
dtype=dtype,
version=version,
) as (input_store_path, output_store_path):
# Populate Store with random data
populate_store(input_store_path, position_keys, shape, dtype)

# Choose a single position to process (e.g., the first one)
for position_key_tuple in position_keys:
input_position_path = input_store_path / Path(*position_key_tuple)
output_position_path = output_store_path / Path(*position_key_tuple)
kwargs = {"constant": constant, "extra_metadata": {"temp": 10}}

# Apply the transformation using process_single_position
process_single_position(
func=dummy_transform,
input_position_path=input_position_path,
Expand All @@ -768,18 +765,15 @@ def test_process_single_position(setup, constant):
output_channel_indices=channel_indices,
input_time_indices=time_indices,
output_time_indices=time_indices,
num_threads=num_threads,
**kwargs,
)

# Handle None for process_single_position_setup
if time_indices is None:
time_indices = list(range(shape[0]))
if channel_indices is None:
channel_indices = [[c] for c in range(shape[1])]

print("time_indices", time_indices)
print("channel_indices", channel_indices)
# Verify the transformation
iterable = itertools.product(time_indices, channel_indices)
for t_idx, chan_idx in iterable:
verify_transformation(
Expand Down
Loading