
Commit a9a4a84

add tests
1 parent 851fc1c commit a9a4a84

1 file changed

Lines changed: 255 additions & 0 deletions

File tree

tests/ngff/test_ngff_utils.py

@@ -7,13 +7,16 @@
 
 import hypothesis.strategies as st
 import numpy as np
+import pytest
 from hypothesis import assume, given, settings
 from numpy.typing import DTypeLike
 
 from iohub.ngff import open_ome_zarr
 from iohub.ngff.utils import (
     _indices_to_shard_aligned_batches,
     _match_indices_to_batches,
+    _V04_MAX_CHUNK_SIZE_BYTES,
+    _V05_DEFAULT_ZYX_CHUNKS,
     apply_transform_to_tczyx_and_save,
     create_empty_plate,
     process_single_position,
@@ -786,3 +789,255 @@ def test_process_single_position(setup, constant, num_threads):
         dummy_transform,
         **kwargs,
     )
+
+
+# -- Explicit tests for version-specific chunk/shard defaults -----------------
+#
+# The hypothesis-based test_create_empty_plate exercises many parameter
+# combinations but does not assert the exact defaults the issue #401 spec
+# prescribes. These tests pin those defaults down so CI fails deterministically
+# if they regress, rather than relying on a favorable hypothesis draw.
+
+
+def _open_array(store_path: Path, position_key: tuple[str, str, str]):
+    return open_ome_zarr(store_path)["/".join(position_key)].data
+
+
+@pytest.mark.parametrize(
+    ("shape", "expected_chunks"),
+    [
+        # Large shape: chunks clamped to DCA spec (16, 256, 256).
+        ((2, 2, 64, 1024, 1024), (1, 1, 16, 256, 256)),
+        # Small Z: clamped to Z.
+        ((2, 2, 8, 1024, 1024), (1, 1, 8, 256, 256)),
+        # Small YX: clamped to YX.
+        ((2, 2, 64, 128, 200), (1, 1, 16, 128, 200)),
+        # Fully smaller than defaults.
+        ((1, 1, 4, 32, 32), (1, 1, 4, 32, 32)),
+    ],
+)
+def test_v05_default_chunks(tmp_path, shape, expected_chunks):
+    """v0.5 default chunks are DCA-aligned (16, 256, 256), clamped to shape."""
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=[f"c{i}" for i in range(shape[1])],
+        shape=shape,
+        version="0.5",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    assert arr.chunks == expected_chunks
+
+
+def test_v05_default_shards_cover_zyx(tmp_path):
+    """v0.5 default shards have shape (1, 1, Z, Y, X) — one shard per (T, C)."""
+    shape = (3, 2, 64, 1024, 1024)
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=[f"c{i}" for i in range(shape[1])],
+        shape=shape,
+        version="0.5",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    # Shard spans one full (Z, Y, X) volume per (T, C) slot.
+    assert arr.shards == (1, 1, shape[2], shape[3], shape[4])
+    # And chunks stay DCA-aligned.
+    assert arr.chunks == (1, 1, *_V05_DEFAULT_ZYX_CHUNKS)
+
+
+def test_v05_default_shards_with_non_divisible_zyx(tmp_path):
+    """Shards still cover the full (Z, Y, X) even when dims are not multiples of chunks."""
+    # Z=20, Y=300, X=300 — none divide evenly into (16, 256, 256).
+    shape = (1, 1, 20, 300, 300)
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=["c0"],
+        shape=shape,
+        version="0.5",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    # Shard = chunk * ceil(dim / chunk) — must be >= dim along each axis.
+    assert arr.chunks == (1, 1, 16, 256, 256)
+    assert arr.shards[0] == 1
+    assert arr.shards[1] == 1
+    assert arr.shards[2] >= 20
+    assert arr.shards[3] >= 300
+    assert arr.shards[4] >= 300
+
+
+def test_v05_explicit_shards_ratio_is_honored(tmp_path):
+    """An explicit shards_ratio overrides the default."""
+    shape = (4, 2, 16, 256, 256)
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=[f"c{i}" for i in range(shape[1])],
+        shape=shape,
+        chunks=(1, 1, 16, 256, 256),
+        shards_ratio=(2, 2, 1, 1, 1),
+        version="0.5",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    assert arr.chunks == (1, 1, 16, 256, 256)
+    assert arr.shards == (2, 2, 16, 256, 256)
+
+
+def test_v04_default_chunks_cover_full_zyx(tmp_path):
+    """v0.4 default chunks are (1, 1, Z, Y, X) when under the byte cap."""
+    shape = (2, 2, 4, 64, 64)
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=[f"c{i}" for i in range(shape[1])],
+        shape=shape,
+        version="0.4",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    assert arr.chunks == (1, 1, shape[2], shape[3], shape[4])
+
+
+def test_v04_default_chunks_capped_by_byte_limit(tmp_path):
+    """v0.4 chunks halve Z until the chunk fits under _V04_MAX_CHUNK_SIZE_BYTES."""
+    # Pick a shape whose single (Z, Y, X) volume in float32 exceeds the cap.
+    # float32 is 4 bytes; cap is 500 MB → a (256, 1024, 1024) volume is
+    # 1 GiB, so the default must halve Z at least once.
+    shape = (1, 1, 256, 1024, 1024)
+    dtype = np.float32
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=["c0"],
+        shape=shape,
+        dtype=dtype,
+        version="0.4",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    t_chunk, c_chunk, z_chunk, y_chunk, x_chunk = arr.chunks
+    assert (t_chunk, c_chunk) == (1, 1)
+    assert (y_chunk, x_chunk) == (shape[3], shape[4])
+    assert z_chunk < shape[2], "Z should have been halved to respect byte cap"
+    bytes_per_chunk = z_chunk * y_chunk * x_chunk * np.dtype(dtype).itemsize
+    assert bytes_per_chunk <= _V04_MAX_CHUNK_SIZE_BYTES
+
+
+def test_v04_default_has_no_sharding(tmp_path):
+    """v0.4 (Zarr v2) never has a sharding codec, regardless of the new defaults."""
+    store = tmp_path / "test.zarr"
+    create_empty_plate(
+        store_path=store,
+        position_keys=[("A", "1", "0")],
+        channel_names=["c0"],
+        shape=(2, 1, 8, 64, 64),
+        version="0.4",
+    )
+    arr = _open_array(store, ("A", "1", "0"))
+    assert arr.shards is None
+
+
+def test_v04_rejects_explicit_shards_ratio(tmp_path):
+    """Passing shards_ratio on a v0.4 store raises (Zarr v2 has no sharding)."""
+    store = tmp_path / "test.zarr"
+    with pytest.raises(ValueError, match="Sharding is not supported in Zarr v2"):
+        create_empty_plate(
+            store_path=store,
+            position_keys=[("A", "1", "0")],
+            channel_names=["c0"],
+            shape=(2, 1, 8, 64, 64),
+            shards_ratio=(1, 1, 1, 1, 1),
+            version="0.4",
+        )
+
+
+# -- Write path on sharded v0.5 stores ---------------------------------------
+#
+# iohub ships with ``zarrs`` as a required dependency and the
+# ``ZarrsCodecPipeline`` is the active codec pipeline. That pipeline handles
+# oindex writes into sharded Zarr v3 arrays correctly, so the upstream
+# zarr-python#2834 bug (which affects ``BatchedCodecPipeline``) does not
+# surface through iohub's default code path. These tests pin that behavior.
+
+
+def test_process_single_position_on_sharded_v05_store(tmp_path):
+    """process_single_position writes to a default-sharded v0.5 store correctly."""
+    shape = (2, 1, 4, 16, 16)
+    position_key = ("A", "1", "0")
+    input_store = tmp_path / "input.zarr"
+    output_store = tmp_path / "output.zarr"
+    for store in (input_store, output_store):
+        create_empty_plate(
+            store_path=store,
+            position_keys=[position_key],
+            channel_names=["c0"],
+            shape=shape,
+            version="0.5",
+        )
+    populate_store(input_store, [position_key], shape, np.float32)
+
+    process_single_position(
+        func=dummy_transform,
+        input_position_path=input_store / Path(*position_key),
+        output_position_path=output_store / Path(*position_key),
+        input_channel_indices=[[0]],
+        output_channel_indices=[[0]],
+        input_time_indices=[0, 1],
+        output_time_indices=[0, 1],
+        constant=2,
+    )
+
+    out_arr = _open_array(output_store, position_key)
+    assert out_arr.shards == (1, 1, shape[2], shape[3], shape[4])
+
+    with open_ome_zarr(input_store) as in_ds, open_ome_zarr(output_store) as out_ds:
+        in_data = in_ds["/".join(position_key)].data[:]
+        out_data = out_ds["/".join(position_key)].data[:]
+        np.testing.assert_array_almost_equal(out_data, dummy_transform(in_data, constant=2))
+
+
+def test_apply_transform_to_tczyx_on_multi_time_shard(tmp_path):
+    """Multi-time oindex write into a shard that spans multiple T slots.
+
+    This is the exact pattern that breaks under zarr-python's default
+    ``BatchedCodecPipeline`` (upstream bug #2834). It works under iohub's
+    ``ZarrsCodecPipeline``-backed config, which is the guarantee we want
+    to lock in.
+    """
+    shape = (4, 1, 4, 16, 16)
+    shards_ratio = (2, 1, 1, 1, 1)  # shard_t = 2 -> one write spans both T slots
+    position_key = ("A", "1", "0")
+    input_store = tmp_path / "input.zarr"
+    output_store = tmp_path / "output.zarr"
+    for store in (input_store, output_store):
+        create_empty_plate(
+            store_path=store,
+            position_keys=[position_key],
+            channel_names=["c0"],
+            shape=shape,
+            chunks=(1, 1, 4, 16, 16),
+            shards_ratio=shards_ratio,
+            version="0.5",
+        )
+    populate_store(input_store, [position_key], shape, np.float32)
+
+    apply_transform_to_tczyx_and_save(
+        func=dummy_transform,
+        input_position_path=input_store / Path(*position_key),
+        output_position_path=output_store / Path(*position_key),
+        input_channel_indices=[0],
+        output_channel_indices=[0],
+        input_time_indices=[0, 1],
+        output_time_indices=[0, 1],
+        constant=2,
+    )
+
+    with open_ome_zarr(input_store) as in_ds, open_ome_zarr(output_store) as out_ds:
+        in_slice = in_ds["/".join(position_key)].data[:2, :1]
+        out_slice = out_ds["/".join(position_key)].data[:2, :1]
+        np.testing.assert_array_almost_equal(out_slice, dummy_transform(in_slice, constant=2))

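The parametrized expectations in test_v05_default_chunks all follow one rule: take the (16, 256, 256) default and clamp each (Z, Y, X) axis to the array shape. A minimal sketch of that arithmetic, using a hypothetical helper name and the default value stated in the test comments; it mirrors the test oracle only, not iohub's implementation.

# Hypothetical helper; the (16, 256, 256) default and the per-axis min rule
# are taken from the parametrized expectations above, not from iohub's source.
DEFAULT_ZYX_CHUNKS = (16, 256, 256)

def expected_v05_chunks(shape):
    # Clamp the default (Z, Y, X) chunk sizes to the TCZYX array shape.
    zyx = shape[2:]
    return (1, 1) + tuple(min(d, s) for d, s in zip(DEFAULT_ZYX_CHUNKS, zyx))

assert expected_v05_chunks((2, 2, 64, 1024, 1024)) == (1, 1, 16, 256, 256)
assert expected_v05_chunks((2, 2, 8, 1024, 1024)) == (1, 1, 8, 256, 256)
assert expected_v05_chunks((2, 2, 64, 128, 200)) == (1, 1, 16, 128, 200)
assert expected_v05_chunks((1, 1, 4, 32, 32)) == (1, 1, 4, 32, 32)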
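The sharding assertions rest on two pieces of arithmetic stated in the comments: a default shard must cover the full axis, so it is the chunk size times ceil(dim / chunk), and an explicit shards_ratio simply multiplies each chunk dimension. A sketch with hypothetical helper names, checked against the values used in test_v05_default_shards_with_non_divisible_zyx and test_v05_explicit_shards_ratio_is_honored.

import math

def covering_shard(chunks, shape):
    # Smallest chunk-aligned shard that is >= the dimension along every axis.
    return tuple(c * math.ceil(d / c) for c, d in zip(chunks, shape))

def shards_from_ratio(chunks, ratio):
    # An explicit shards_ratio scales each chunk dimension.
    return tuple(c * r for c, r in zip(chunks, ratio))

# Non-divisible (Z, Y, X) = (20, 300, 300) with chunks (16, 256, 256):
assert covering_shard((1, 1, 16, 256, 256), (1, 1, 20, 300, 300)) == (1, 1, 32, 512, 512)
# Explicit ratio (2, 2, 1, 1, 1) on chunks (1, 1, 16, 256, 256):
assert shards_from_ratio((1, 1, 16, 256, 256), (2, 2, 1, 1, 1)) == (2, 2, 16, 256, 256)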
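test_v04_default_chunks_capped_by_byte_limit only asserts the end state (Z was reduced and the chunk fits under the cap), but the arithmetic in its comments is easy to trace. A sketch of that halving, assuming a 500 MB stand-in for _V04_MAX_CHUNK_SIZE_BYTES; the exact constant and stopping rule live in iohub.ngff.utils and may differ.

import numpy as np

MAX_CHUNK_BYTES = 500 * 2**20  # assumed stand-in for _V04_MAX_CHUNK_SIZE_BYTES

def capped_z_chunk(z, y, x, dtype=np.float32):
    # Halve Z until one (1, 1, Z, Y, X) chunk fits under the byte cap.
    itemsize = np.dtype(dtype).itemsize
    while z > 1 and z * y * x * itemsize > MAX_CHUNK_BYTES:
        z //= 2
    return z

# A (256, 1024, 1024) float32 volume is 1 GiB, so Z halves twice: 256 -> 128 -> 64.
assert capped_z_chunk(256, 1024, 1024) == 64
assert 64 * 1024 * 1024 * 4 <= MAX_CHUNK_BYTES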
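The last two tests describe the failure mode only in comments: an orthogonal-index (oindex) write that lands inside a shard spanning several T slots. Outside of iohub, that pattern can be reproduced with plain zarr-python plus the zarrs codec pipeline. This is a standalone sketch assuming zarr-python >= 3 and the zarrs package; the config key follows zarrs' documented usage, and none of it is taken from iohub's source.

import numpy as np
import zarr
import zarrs  # noqa: F401  # provides the Rust-backed codec pipeline

# Route codec work through ZarrsCodecPipeline instead of the default
# BatchedCodecPipeline (the one affected by zarr-python#2834).
zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"})

arr = zarr.create_array(
    store="scratch.zarr",
    shape=(4, 1, 4, 16, 16),
    chunks=(1, 1, 4, 16, 16),
    shards=(2, 1, 4, 16, 16),  # shard_t = 2: one shard spans two T slots
    dtype="float32",
    overwrite=True,
)
data = np.random.default_rng(0).random((2, 1, 4, 16, 16), dtype="float32")
# Orthogonal-index write touching both T slots of the first shard.
arr.oindex[[0, 1], [0]] = data
np.testing.assert_array_equal(arr[:2, :1], data)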