Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/iohub/ngff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import inspect
import itertools
import multiprocessing as mp
import os
from collections import defaultdict
from collections.abc import Callable, Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -395,21 +396,19 @@ def process_single_position(
output_position_path,
**kwargs,
)
num_processes = min(num_processes, len(flat_iterable), mp.cpu_count())
click.echo(f"\nStarting multiprocess pool with {num_processes} processes")
if num_processes <= 1:
# Run serially — Pool(1) with spawn unnecessarily forks a subprocess
cpu_count = os.cpu_count() or 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail on BRUNO: `os.cpu_count()` reports the node's full CPU count (~128, depending on the node) rather than the CPUs actually allocated to the job.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_workers = min(num_processes, len(flat_iterable), cpu_count)
Comment thread
srivarra marked this conversation as resolved.
Outdated
click.echo(f"\nStarting thread pool with {num_workers} workers")
if num_workers <= 1:
for args in flat_iterable:
partial_apply_transform_to_czyx_and_save(*args)
else:
# NOTE: use spawn to work around tensorstore#61
context = mp.get_context("spawn")
with context.Pool(num_processes) as p:
p.starmap(
partial_apply_transform_to_czyx_and_save,
with ThreadPoolExecutor(max_workers=num_workers) as executor:
list(executor.map(
Comment thread
aofei-liu marked this conversation as resolved.
lambda args: partial_apply_transform_to_czyx_and_save(*args),
flat_iterable,
)
click.echo("Shut down multiprocess pool")
))
Comment thread
aofei-liu marked this conversation as resolved.
click.echo("Shut down thread pool")


# -- Pure utility functions ------------------------------------------------
Expand Down
67 changes: 67 additions & 0 deletions tests/ngff/test_ngff_utils.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. The only suggestion is to merge this new test into the previous one — for example, add a `thread_count` (or similar) variable to the `@given` strategy so both execution paths are covered by a single test.

Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,70 @@ def test_process_single_position(setup, constant):
dummy_transform,
**kwargs,
)


@given(
    setup=process_single_position_setup(),
    constant=st.integers(min_value=1, max_value=3),
)
@settings(max_examples=3, deadline=None)
def test_process_single_position_threaded(setup, constant):
    """Run ``process_single_position`` with ``num_processes=2`` to request
    the concurrent (thread-pool) execution path, then check that every
    (time, channel) slice of each position was transformed as expected.

    NOTE(review): the worker count is clamped internally, so with very
    small inputs the serial fallback may still run — confirm if the
    concurrent path must be forced.
    """
    (
        position_keys,
        channel_names,
        shape,
        chunks,
        shards_ratio,
        scale,
        dtype,
        channel_indices,
        time_indices,
        version,
    ) = setup

    # Build the paired input/output OME-Zarr stores for this example.
    with _temp_ome_zarr_stores(
        position_keys=position_keys,
        channel_names=channel_names,
        shape=shape,
        chunks=chunks,
        shards_ratio=shards_ratio,
        scale=scale,
        dtype=dtype,
        version=version,
    ) as (src_store, dst_store):
        populate_store(src_store, position_keys, shape, dtype)

        for pos_key in position_keys:
            # kwargs are rebuilt per position so any in-callee mutation of
            # extra_metadata cannot leak between iterations.
            transform_kwargs = {
                "constant": constant,
                "extra_metadata": {"temp": 10},
            }

            process_single_position(
                func=dummy_transform,
                input_position_path=src_store / Path(*pos_key),
                output_position_path=dst_store / Path(*pos_key),
                input_channel_indices=channel_indices,
                output_channel_indices=channel_indices,
                input_time_indices=time_indices,
                output_time_indices=time_indices,
                num_processes=2,
                **transform_kwargs,
            )

            # None means "all indices": expand to the full ranges implied
            # by the array shape (T = shape[0], C = shape[1]).
            if time_indices is None:
                time_indices = list(range(shape[0]))
            if channel_indices is None:
                channel_indices = [[c] for c in range(shape[1])]

            for t_idx, chan_idx in itertools.product(
                time_indices, channel_indices
            ):
                verify_transformation(
                    src_store,
                    dst_store,
                    pos_key,
                    shape,
                    t_idx,
                    chan_idx,
                    dummy_transform,
                    **transform_kwargs,
                )
Loading