
Commit 48399f6

ziw-liu, ieivanov, and srivarra authored
Concatenate with OME-Zarr v0.5 and sharding (#104)
* add zarrs and organize dependency groups
* configurable sharding
* update chunking test
* use clean env helper #96 (comment)
* update example config
* disable threading for the zarrs codec
* test variable sharding in time
* print the correct cluster name
* allow blocking
* fix typing
* wip: test values of the concatenated array
* fix monitoring
* remove zarrs codec
* tweak resource estimation
* block in testing
* require tensorstore
* update dependency groups
* combine context managers
* ultrack lazy import
* style
* point to the pre-release
* raise error if trying to concatenate zarr stores with pyramids

---------

Co-authored-by: Ivan Ivanov <ivan.ivanov@czbiohub.org>
Co-authored-by: Sricharan Reddy Varra <sricharan.varra@czbiohub.org>
1 parent 82bfbee commit 48399f6

6 files changed

Lines changed: 96 additions & 28 deletions

File tree

biahub/concatenate.py
biahub/settings.py
biahub/track.py
pyproject.toml
settings/example_concatenate_settings.yml
tests/test_concatenate.py

biahub/concatenate.py

Lines changed: 37 additions & 15 deletions
@@ -238,22 +238,32 @@ def calculate_cropped_size(
 def concatenate(
     settings: ConcatenateSettings,
     output_dirpath: Path,
-    sbatch_filepath: str = None,
+    sbatch_filepath: str | None = None,
     local: bool = False,
+    block: bool = False,
     monitor: bool = True,
 ):
-    """
-    Concatenate datasets (with optional cropping)
-
-    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
+    """Concatenate datasets (with optional cropping).
+
+    Parameters
+    ----------
+    settings : ConcatenateSettings
+        Configuration settings for concatenation
+    output_dirpath : Path
+        Path to the output dataset
+    sbatch_filepath : str | None, optional
+        Path to the SLURM batch file, by default None
+    local : bool, optional
+        Whether to run locally or on a cluster, by default False
+    block : bool, optional
+        Whether to block until all the jobs are complete,
+        by default False
+    monitor : bool, optional
+        Whether to monitor the jobs, by default True
     """
     slurm_out_path = output_dirpath.parent / "slurm_output"

-    slicing_params = [
-        settings.Z_slice,
-        settings.Y_slice,
-        settings.X_slice,
-    ]
+    slicing_params = [settings.Z_slice, settings.Y_slice, settings.X_slice]
     (
         all_data_paths,
         all_channel_names,
@@ -274,6 +284,11 @@ def concatenate(
     all_voxel_sizes = []
     for path in all_data_paths:
         with open_ome_zarr(path) as dataset:
+            if len(dataset.array_keys()) > 1:
+                # TODO: https://github.com/czbiohub-sf/biahub/issues/192
+                raise ValueError(
+                    "Concatenation of datasets with multiple arrays (pyramid levels) is not supported."
+                )
             all_shapes.append(dataset.data.shape)
             all_dtypes.append(dataset.data.dtype)
             all_voxel_sizes.append(dataset.scale[-3:])
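
The guard above rejects multi-scale inputs before any work is scheduled. A minimal standalone sketch of the same check, assuming a hypothetical input path:

    from iohub import open_ome_zarr

    # Hypothetical position path; pyramid stores expose more than one
    # array key and are rejected until issue #192 is resolved.
    with open_ome_zarr("/path/to/data.zarr/A/1/0") as dataset:
        n_arrays = len(dataset.array_keys())
        if n_arrays > 1:
            raise ValueError(f"Found {n_arrays} arrays; expected a single scale.")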
@@ -334,11 +349,12 @@ def concatenate(
         chunk_size = [1] + list(settings.chunks_czyx)
     else:
         chunk_size = settings.chunks_czyx
-
     # Logic for creation of zarr and metadata
     output_metadata = {
         "shape": (len(input_time_indices), len(all_channel_names)) + tuple(cropped_shape_zyx),
         "chunks": chunk_size,
+        "shards_ratio": settings.shards_ratio,
+        "version": settings.output_ome_zarr_version,
         "scale": (1,) * 2 + tuple(output_voxel_size),
         "channel_names": all_channel_names,
         "dtype": dtype,
@@ -352,8 +368,9 @@ def concatenate(
     )

     # Estimate resources
+    batch_size = settings.shards_ratio[0] if settings.shards_ratio else 1
     num_cpus, gb_ram_per_cpu = estimate_resources(
-        shape=[T, C, Z, Y, X], ram_multiplier=16, max_num_cpus=16
+        shape=(T // batch_size, C, Z, Y, X), ram_multiplier=4 * batch_size, max_num_cpus=16
    )
    # Prepare SLURM arguments
    slurm_args = {
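
Since each job now handles a shard-sized batch of timepoints, the heuristic divides the time extent by the shard's time ratio and scales RAM with the batch size. Illustrative numbers only, following the expressions in the hunk above:

    T, C, Z, Y, X = 10, 2, 5, 256, 256
    shards_ratio = [2, 1, 1, 2, 2]

    batch_size = shards_ratio[0] if shards_ratio else 1  # -> 2
    shape = (T // batch_size, C, Z, Y, X)                # -> (5, 2, 5, 256, 256)
    ram_multiplier = 4 * batch_size                      # -> 8, vs. the old flat 16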
@@ -380,8 +397,9 @@ def concatenate(
     executor = submitit.AutoExecutor(folder=slurm_out_path, cluster=cluster)
     executor.update_parameters(**slurm_args)

-    click.echo("Submitting SLURM jobs...")
+    click.echo(f"Submitting {cluster} jobs...")
     jobs = []
+
     with submitit.helpers.clean_env(), executor.batch():
         for i, (
             input_position_path,
@@ -424,6 +442,9 @@ def concatenate(
     with log_path.open("w") as log_file:
         log_file.write("\n".join(job_ids))

+    if block:
+        _ = [job.result() for job in jobs]
+
     if monitor:
         monitor_jobs(jobs, all_data_paths)
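
submitit's Job.result() waits for the job to finish and re-raises any exception it hit, so the list comprehension above acts as a barrier across all submitted jobs. A self-contained sketch with a trivial payload:

    import submitit

    executor = submitit.AutoExecutor(folder="slurm_output")  # falls back to local without SLURM
    executor.update_parameters(timeout_min=5)

    jobs = [executor.submit(pow, 2, n) for n in range(4)]
    results = [job.result() for job in jobs]  # blocks until every job completes
    assert results == [1, 2, 4, 8]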

@@ -437,21 +458,22 @@ def concatenate(
 def concatenate_cli(
     config_filepath: Path,
     output_dirpath: str,
-    sbatch_filepath: str = None,
+    sbatch_filepath: str | None = None,
     local: bool = False,
     monitor: bool = True,
 ):
     """
     Concatenate datasets (with optional cropping)

-    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr -j 8
+    >> biahub concatenate -c ./concat.yml -o ./output_concat.zarr
     """

     concatenate(
         settings=yaml_to_model(config_filepath, ConcatenateSettings),
         output_dirpath=Path(output_dirpath),
         sbatch_filepath=sbatch_filepath,
         local=local,
+        block=False,
         monitor=monitor,
     )

biahub/settings.py

Lines changed: 2 additions & 0 deletions
@@ -373,7 +373,9 @@ class ConcatenateSettings(MyBaseModel):
     Y_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
     Z_slice: Union[list, list[Union[list, Literal["all"]]], Literal["all"]] = "all"
     chunks_czyx: Union[Literal[None], list[int]] = None
+    shards_ratio: list[int] | None = None
     ensure_unique_positions: Optional[bool] = False
+    output_ome_zarr_version: Literal["0.4", "0.5"] = "0.4"

     @field_validator("concat_data_paths")
     @classmethod
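
Both new fields are optional with backward-compatible defaults, so existing configs keep producing unsharded OME-Zarr v0.4 output. A minimal sketch (placeholder paths; whether the concat_data_paths validator accepts globs that match nothing is an assumption here):

    from biahub.settings import ConcatenateSettings

    s = ConcatenateSettings(
        concat_data_paths=["/path/to/data.zarr/*/*/*"],
        channel_names=["all"],
        time_indices="all",
    )
    assert s.shards_ratio is None               # no sharding unless requested
    assert s.output_ome_zarr_version == "0.4"   # v0.4 remains the default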

biahub/track.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
 )
 from biahub.settings import ProcessingInputChannel, TrackingSettings

+# Lazy imports for ultrack - imported only when needed in specific functions
+

 def mem_nuc_contour(nuclei_prediction: ArrayLike, membrane_prediction: ArrayLike) -> ArrayLike:
     """

pyproject.toml

Lines changed: 6 additions & 4 deletions
@@ -20,8 +20,8 @@ classifiers = [

 # list package dependencies here
 dependencies = [
+    "iohub>=0.3.0a2,<0.4",
     "stitch @ git+https://github.com/ahillsley/stitching@jen",
-    "iohub>=0.2,<0.3",
     "matplotlib",
     "napari",
     "PyQt6",
@@ -35,7 +35,7 @@ dependencies = [
     "submitit",
     "torch",
     "tqdm",
-    "waveorder==3.0.0a1",
+    "waveorder==3.0.0a2",
     "largestinteriorrectangle",
     "antspyx",
     "pystackreg",
@@ -52,11 +52,13 @@ dependencies = [
 [project.optional-dependencies]
 segment = ["cellpose"]

-track = ["ultrack>=0.6.3"]
+track = ["ultrack>=0.7.0rc2"]
+
+shard = ["tensorstore"]

 build = ["build", "twine"]

-all = ["biahub[segment,track,build]"]
+all = ["biahub[segment,track,shard,build]"]

 dev = [
     "biahub[all]",

settings/example_concatenate_settings.yml

Lines changed: 16 additions & 5 deletions
@@ -4,8 +4,8 @@
 # List of paths to concatenate - can use glob patterns
 # Each path will be treated as a separate input dataset
 concat_data_paths:
-  - "/path/to/data1.zarr/*/*/*" # First dataset
-  - "/path/to/data2.zarr/*/*/*" # Second dataset
+  - "/path/to/data1.zarr/*/*/*" # First dataset
+  - "/path/to/data2.zarr/*/*/*" # Second dataset
 # - "/path/to/data3.zarr/A/1/0" # You can also specify exact positions

 # Time indices to include in the output
@@ -22,8 +22,8 @@ time_indices: "all"
 # - For multiple datasets, specify channels for each:
 #   [["DAPI"], ["GFP", "RFP"]] - Take DAPI from first dataset, GFP and RFP from second
 channel_names:
-  - "all" # Include all channels from first dataset
-  - "all" # Include all channels from second dataset
+  - "all" # Include all channels from first dataset
+  - "all" # Include all channels from second dataset

@@ -55,12 +55,23 @@ Z_slice: "all"
 # - [1, 10, 100, 100]: Specify custom chunk sizes
 chunks_czyx: null

+# Number of chunks in a shard for each dimension [T, C, Z, Y, X]
+# Options:
+# - null: No sharding
+# - [1, 1, 4, 8, 8]: Specify custom shards ratio
+shards_ratio: null
+
+# Version of the OME-Zarr format to use for the output
+# Options:
+# - "0.4" (default)
+# - "0.5"
+output_ome_zarr_version: "0.4"
+
 # Whether to ensure unique position names in the output
 # Options:
 # - false or null: Positions with the same name will overwrite each other
 # - true: Ensure unique position names by adding suffixes (e.g., A/1d1/0)
 ensure_unique_positions: null
-
 # EXAMPLE USE CASES:

 # 1. Basic concatenation of all data:
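
A hedged sketch of loading this example file into the settings model (assuming PyYAML and that the keys map one-to-one onto ConcatenateSettings fields, as the class above suggests):

    import yaml

    from biahub.settings import ConcatenateSettings

    with open("settings/example_concatenate_settings.yml") as f:
        settings = ConcatenateSettings(**yaml.safe_load(f))

    print(settings.shards_ratio)             # None -> no sharding
    print(settings.output_ome_zarr_version)  # "0.4"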

tests/test_concatenate.py

Lines changed: 33 additions & 4 deletions
@@ -1,3 +1,6 @@
+import numpy as np
+import pytest
+
 from iohub import open_ome_zarr

 from biahub.concatenate import concatenate
@@ -204,7 +207,13 @@ def test_concatenate_with_cropping(create_custom_plate, tmp_path, sbatch_file):
     assert output_X == x_end - x_start


-def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_file):
+@pytest.mark.parametrize(
+    ["version", "shards_ratio_time"],
+    [["0.4", 1], ["0.5", None], ["0.5", 1], ["0.5", 2], ["0.5", 5]],
+)
+def test_concatenate_with_custom_chunks(
+    create_custom_plate, tmp_path, sbatch_file, version, shards_ratio_time
+):
     """
     Test concatenating with custom chunk sizes
     """
@@ -227,13 +236,22 @@ def test_concatenate_with_custom_chunks(create_custom_plate, tmp_path, sbatch_file):
     )

     # Define custom chunk sizes
-    custom_chunks = [1, 2, 4, 3]  # [C, Z, Y, X]
+    chunks = [1, 1, 2, 4, 3]  # [T, C, Z, Y, X]
+    if version == "0.5":
+        if shards_ratio_time is None:
+            shards_ratio = None
+        else:
+            shards_ratio = [shards_ratio_time, 1, 1, 2, 2]
+    elif version == "0.4":
+        shards_ratio = None

     settings = ConcatenateSettings(
         concat_data_paths=[str(plate_1_path) + "/*/*/*", str(plate_2_path) + "/*/*/*"],
         channel_names=['all', 'all'],
         time_indices='all',
-        chunks_czyx=custom_chunks,
+        chunks_czyx=chunks[1:],
+        shards_ratio=shards_ratio,
+        output_ome_zarr_version=version,
     )

     output_path = tmp_path / "output.zarr"
@@ -242,10 +260,21 @@
         output_dirpath=output_path,
         sbatch_filepath=sbatch_file,
         local=True,
+        monitor=False,
+        block=True,
     )

-    # We can't easily check the chunks directly, but we can verify the operation completed successfully
     output_plate = open_ome_zarr(output_path)
+    for pos_name, pos in output_plate.positions():
+        assert pos.data.chunks == tuple(chunks)
+        if version == "0.5" and shards_ratio is not None:
+            assert pos.data.shards == tuple(c * s for c, s in zip(chunks, shards_ratio))
+        np.testing.assert_array_equal(
+            pos.data.numpy(),
+            np.concatenate(
+                [plate_1[pos_name].data.numpy(), plate_2[pos_name].data.numpy()], axis=1
+            ),
+        )

     # Check that the output plate has all the channels from the input plates
     output_channels = output_plate.channel_names
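
The axis=1 in the equality check reflects the layout: arrays are (T, C, Z, Y, X), and the two input plates contribute the same positions with different channels, so the output stacks along the channel axis. A toy confirmation:

    import numpy as np

    plate_1_data = np.zeros((2, 1, 5, 8, 8))  # one channel
    plate_2_data = np.ones((2, 2, 5, 8, 8))   # two more channels
    merged = np.concatenate([plate_1_data, plate_2_data], axis=1)
    assert merged.shape == (2, 3, 5, 8, 8)    # channels: 1 + 2 = 3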
