Skip to content

Commit e7db433

Browse files
committed
fix: adding ome-zarr v0.4 back
Signed-off-by: Sricharan Reddy Varra <sricharan.varra@biohub.org>
1 parent 9a3158f commit e7db433

5 files changed

Lines changed: 156 additions & 52 deletions

File tree

src/iohub/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class ZarrConfig(BaseModel):
3636
class TensorStoreConfig(BaseModel):
3737
"""Config for the TensorStore implementation."""
3838

39+
compressor: CompressorConfig = Field(default_factory=CompressorConfig)
3940
data_copy_concurrency: int = Field(default=4, ge=1)
4041
context: dict | None = None
4142
file_io_concurrency: int | None = None

src/iohub/core/implementations/tensorstore.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,20 @@ def __init__(self, group: _TsGroup):
3737
super().__init__(self._load())
3838

3939
def _load(self) -> dict:
40+
if self._group.zarr_driver == "zarr2":
41+
zattrs = self._group.path / ".zattrs"
42+
if zattrs.exists():
43+
return json.loads(zattrs.read_text())
44+
return {}
4045
zarr_json = self._group.path / "zarr.json"
4146
if zarr_json.exists():
4247
return json.loads(zarr_json.read_text()).get("attributes", {})
43-
zattrs = self._group.path / ".zattrs"
44-
if zattrs.exists():
45-
return json.loads(zattrs.read_text())
4648
return {}
4749

4850
def _save(self) -> None:
51+
if self._group.zarr_driver == "zarr2":
52+
(self._group.path / ".zattrs").write_text(json.dumps(dict(self)))
53+
return
4954
zarr_json = self._group.path / "zarr.json"
5055
if zarr_json.exists():
5156
meta = json.loads(zarr_json.read_text())
@@ -63,13 +68,17 @@ def update(self, *args, **kwargs):
6368
self._save()
6469

6570

66-
def _detect_zarr_driver(path: Path) -> str:
71+
def _detect_zarr_driver(path: Path, zarr_format: int | None = None) -> str:
6772
"""Detect zarr format for a store root. Called once at open_group time."""
73+
if zarr_format == 2:
74+
return "zarr2"
75+
if zarr_format == 3:
76+
return "zarr3"
6877
if (path / "zarr.json").exists():
6978
return "zarr3"
7079
if (path / ".zattrs").exists() or (path / ".zgroup").exists():
7180
return "zarr2"
72-
return "zarr3" # new stores will be created as v3 always
81+
return "zarr3"
7382

7483

7584
class _TsGroup:
@@ -87,7 +96,10 @@ def __init__(
8796
raise FileExistsError(f"Store already exists: {path}")
8897
if mode in ("w", "w-", "a") and not path.exists():
8998
path.mkdir(parents=True, exist_ok=True)
90-
(path / "zarr.json").write_text('{"zarr_format": 3, "node_type": "group", "attributes": {}}')
99+
if zarr_driver == "zarr2":
100+
(path / ".zgroup").write_text('{"zarr_format": 2}')
101+
else:
102+
(path / "zarr.json").write_text('{"zarr_format": 3, "node_type": "group", "attributes": {}}')
91103
self.path = path
92104
self.mode = mode
93105
self._impl = impl
@@ -99,9 +111,14 @@ def create_group(self, name: str, overwrite: bool = False) -> _TsGroup:
99111
if sub.exists() and not overwrite:
100112
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver)
101113
sub.mkdir(parents=True, exist_ok=True)
102-
zarr_json = sub / "zarr.json"
103-
if not zarr_json.exists() or overwrite:
104-
zarr_json.write_text(json.dumps({"zarr_format": 3, "node_type": "group", "attributes": {}}))
114+
if self.zarr_driver == "zarr2":
115+
zgroup = sub / ".zgroup"
116+
if not zgroup.exists() or overwrite:
117+
zgroup.write_text('{"zarr_format": 2}')
118+
else:
119+
zarr_json = sub / "zarr.json"
120+
if not zarr_json.exists() or overwrite:
121+
zarr_json.write_text(json.dumps({"zarr_format": 3, "node_type": "group", "attributes": {}}))
105122
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver)
106123

107124
def __contains__(self, name: str) -> bool:
@@ -271,7 +288,8 @@ def _context(self) -> ts.Context:
271288

272289
def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> _TsGroup:
273290
p = Path(path)
274-
return _TsGroup(path=p, mode=mode, impl=self, zarr_driver=_detect_zarr_driver(p), root=p)
291+
driver = _detect_zarr_driver(p, zarr_format)
292+
return _TsGroup(path=p, mode=mode, impl=self, zarr_driver=driver, root=p)
275293

276294
def _iter_children(self, group: _TsGroup, node_type: str) -> list[str]:
277295
"""Return sorted child names matching node_type ('group' or 'array')."""
@@ -306,7 +324,7 @@ def close(self, group: _TsGroup) -> None:
306324
pass # TensorStore handles are not persistent connections
307325

308326
def get_zarr_format(self, group: _TsGroup) -> int:
309-
return 3 # TensorStore only supports zarr v3
327+
return 2 if group.zarr_driver == "zarr2" else 3
310328

311329
# -- Array lifecycle ---------------------------------------------------
312330

@@ -325,14 +343,21 @@ def create_array_v2(
325343
fill_value: int = 0,
326344
overwrite: bool = False,
327345
) -> ts.TensorStore:
346+
shuffle_map = {"noshuffle": 0, "shuffle": 1, "bitshuffle": 2}
347+
comp = self.config.compressor
328348
spec = {
329349
"driver": "zarr2",
330350
"kvstore": {"driver": "file", "path": str(Path(group.path) / name)},
331351
"metadata": {
332352
"shape": list(shape),
333353
"chunks": list(chunks),
334354
"dtype": np.dtype(dtype).str, # zarr2 uses NumPy dtype strings e.g. "<u2"
335-
"compressor": {"id": "blosc", "cname": "lz4", "clevel": 5, "shuffle": 1},
355+
"compressor": {
356+
"id": "blosc",
357+
"cname": comp.cname,
358+
"clevel": comp.clevel,
359+
"shuffle": shuffle_map.get(comp.shuffle, 2),
360+
},
336361
"fill_value": fill_value,
337362
"order": "C",
338363
"filters": None,

src/iohub/ngff/nodes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _open_store(
9292
try:
9393
zarr_format = None
9494
if mode in ("w", "w-") or (is_fs and mode == "a" and not store_path.exists()):
95-
zarr_format = 3
95+
zarr_format = 2 if version == "0.4" else 3
9696
root = impl.open_group(store_path, mode=mode, zarr_format=zarr_format)
9797
except (FileNotFoundError, FileExistsError, PermissionError):
9898
raise
@@ -382,6 +382,10 @@ def _create_zarr_array(
382382
shards = tuple(c * s for c, s in zip(chunks, shards_ratio, strict=False))
383383
else:
384384
shards = None
385+
if shards is not None and self._zarr_format == 2:
386+
raise ValueError(
387+
"Sharding is not supported in Zarr v2 (OME-Zarr v0.4). Remove shards_ratio or use version='0.5'."
388+
)
385389
if self._zarr_format == 3:
386390
spec = ArraySpec.create(
387391
shape=shape,
@@ -3066,8 +3070,6 @@ def open_ome_zarr(
30663070
if _is_fslike(store_path):
30673071
store_path = Path(store_path)
30683072
_is_new_store = mode in ("w", "w-") or (mode == "a" and _is_fslike(store_path) and not store_path.exists())
3069-
if version == "0.4" and _is_new_store:
3070-
raise ValueError("Creating new OME-Zarr v0.4 stores is not supported. Use version='0.5' instead.")
30713073
parse_meta = _check_file_mode(store_path, mode, disable_path_checking=disable_path_checking)
30723074
root, impl = _open_store(
30733075
store_path,

src/iohub/ngff/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,12 @@ def process_single_position(
418418
partial_apply_transform_to_czyx_and_save(*args)
419419
else:
420420
with ThreadPoolExecutor(max_workers=num_workers) as executor:
421-
list(executor.map(
422-
lambda args: partial_apply_transform_to_czyx_and_save(*args),
423-
flat_iterable,
424-
))
421+
list(
422+
executor.map(
423+
lambda args: partial_apply_transform_to_czyx_and_save(*args),
424+
flat_iterable,
425+
)
426+
)
425427
click.echo("Shut down thread pool")
426428

427429

tests/ngff/test_ngff.py

Lines changed: 107 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
y_dim_st = st.integers(1, 32)
4646
x_dim_st = st.integers(1, 32)
4747
channel_names_st = c_dim_st.flatmap(lambda c_dim: st.lists(short_text_st, min_size=c_dim, max_size=c_dim, unique=True))
48-
ngff_versions_st = st.just("0.5")
48+
ngff_versions_st = st.sampled_from(["0.4", "0.5"])
4949
short_alpha_numeric = st.text(
5050
alphabet=list(string.ascii_lowercase + string.ascii_uppercase + string.digits),
5151
min_size=1,
@@ -238,28 +238,33 @@ def test_init_ome_zarr(channel_names, version):
238238

239239

240240
@pytest.mark.parametrize("mode", ["w", "w-"])
241-
def test_open_ome_zarr_v04_write_raises(tmp_path, mode):
242-
"""Creating new v0.4 stores must raise ValueError for all write modes."""
243-
with pytest.raises(ValueError, match=r"v0.4"):
244-
open_ome_zarr(
245-
tmp_path / "out.zarr",
246-
layout="fov",
247-
mode=mode,
248-
channel_names=["DAPI"],
249-
version="0.4",
250-
)
251-
252-
253-
def test_open_ome_zarr_v04_append_new_path_raises(tmp_path):
254-
"""mode='a' on a nonexistent path is a new store and must also raise."""
255-
with pytest.raises(ValueError, match=r"v0.4"):
256-
open_ome_zarr(
257-
tmp_path / "nonexistent.zarr",
258-
layout="fov",
259-
mode="a",
260-
channel_names=["DAPI"],
261-
version="0.4",
262-
)
241+
def test_open_ome_zarr_v04_write_succeeds(tmp_path, mode):
242+
"""Creating new v0.4 stores must succeed."""
243+
store_path = tmp_path / "out.zarr"
244+
with open_ome_zarr(
245+
store_path,
246+
layout="fov",
247+
mode=mode,
248+
channel_names=["DAPI"],
249+
version="0.4",
250+
) as ds:
251+
assert ds.version == "0.4"
252+
assert (store_path / ".zgroup").exists()
253+
assert not (store_path / "zarr.json").exists()
254+
255+
256+
def test_open_ome_zarr_v04_append_new_path_succeeds(tmp_path):
257+
"""mode='a' on a nonexistent path should create a v0.4 store."""
258+
store_path = tmp_path / "nonexistent.zarr"
259+
with open_ome_zarr(
260+
store_path,
261+
layout="fov",
262+
mode="a",
263+
channel_names=["DAPI"],
264+
version="0.4",
265+
) as ds:
266+
assert ds.version == "0.4"
267+
assert (store_path / ".zgroup").exists()
263268

264269

265270
@pytest.mark.parametrize("version", ["0.5"])
@@ -391,12 +396,14 @@ def test_write_ome_zarr(channels_and_random_5d, arr_name, version):
391396
channel_names, random_5d = channels_and_random_5d
392397
with _temp_ome_zarr(random_5d, channel_names, arr_name, version=version) as dataset:
393398
assert_allclose(dataset[arr_name][:], random_5d)
394-
# round-trip test with the offical reader implementation
395-
ext_reader = ome_zarr.reader.Reader(ome_zarr.io.parse_url(dataset.zgroup.store.root))
396-
node = next(iter(ext_reader()))
397-
assert node.metadata["channel_names"] == channel_names
398-
assert node.specs[0].datasets == [arr_name]
399-
assert_allclose(node.data[0], random_5d)
399+
if version == "0.5":
400+
# round-trip test with the official reader implementation
401+
# ome-zarr-py reader requires zarr-python Group with .store.root
402+
ext_reader = ome_zarr.reader.Reader(ome_zarr.io.parse_url(dataset.zgroup.store.root))
403+
node = next(iter(ext_reader()))
404+
assert node.metadata["channel_names"] == channel_names
405+
assert node.specs[0].datasets == [arr_name]
406+
assert_allclose(node.data[0], random_5d)
400407

401408

402409
@given(
@@ -421,9 +428,10 @@ def test_create_zeros(ch_shape_dtype, arr_name, version):
421428
version=version,
422429
)
423430
dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype)
424-
assert {e.name for e in (Path(store_path) / arr_name).iterdir()} == {
425-
"zarr.json",
426-
}
431+
if version == "0.5":
432+
assert (Path(store_path) / arr_name / "zarr.json").exists()
433+
else:
434+
assert (Path(store_path) / arr_name / ".zarray").exists()
427435
if version == "0.5":
428436
assert dataset[arr_name].metadata.dimension_names == (
429437
"T",
@@ -1198,7 +1206,7 @@ def test_create_position(row, col, pos, version):
11981206
version=version,
11991207
)
12001208
_ = dataset.create_position(row_name=row, col_name=col, pos_name=pos)
1201-
ome = dataset.zgroup.attrs["ome"]
1209+
ome = dict(dataset.zgroup.attrs) if version == "0.4" else dataset.zgroup.attrs["ome"]
12021210
assert [c["name"] for c in ome["plate"]["columns"]] == [col]
12031211
assert [r["name"] for r in ome["plate"]["rows"]] == [row]
12041212
assert (store_path / row / col / pos).is_dir()
@@ -1892,3 +1900,69 @@ def test_initialize_pyramid_invalid_dims(implementation, tmp_path):
18921900
pos.create_zeros("0", shape=(1, 1, 2, 8, 8), dtype=np.float32)
18931901
with pytest.raises(ValueError, match="not in dataset axes"):
18941902
pos.initialize_pyramid(levels=2, dims={"w"})
1903+
1904+
1905+
# ---------- v0.4 dedicated tests ----------
1906+
1907+
1908+
def test_write_ome_zarr_v04_fov_roundtrip(tmp_path):
1909+
"""Full round-trip: create v0.4 FOV store, write image, read back."""
1910+
store_path = tmp_path / "v04.ome.zarr"
1911+
data = np.random.default_rng(42).random((1, 2, 3, 64, 64)).astype(np.float32)
1912+
with open_ome_zarr(
1913+
store_path,
1914+
layout="fov",
1915+
mode="w-",
1916+
channel_names=["A", "B"],
1917+
version="0.4",
1918+
) as ds:
1919+
ds.create_image("0", data)
1920+
assert ds.version == "0.4"
1921+
# Verify v2 file structure
1922+
assert (store_path / ".zgroup").exists()
1923+
assert (store_path / ".zattrs").exists()
1924+
assert not (store_path / "zarr.json").exists()
1925+
# Re-open read-only
1926+
with open_ome_zarr(store_path, layout="fov", mode="r") as ds:
1927+
assert ds.version == "0.4"
1928+
assert_array_equal(ds["0"][:], data)
1929+
assert ds.channel_names == ["A", "B"]
1930+
1931+
1932+
def test_write_ome_zarr_v04_hcs_roundtrip(tmp_path):
1933+
"""HCS plate creation with v0.4."""
1934+
store_path = tmp_path / "v04_hcs.ome.zarr"
1935+
data = np.zeros((1, 2, 3, 32, 32), dtype=np.uint16)
1936+
with open_ome_zarr(
1937+
store_path,
1938+
layout="hcs",
1939+
mode="w-",
1940+
channel_names=["A", "B"],
1941+
version="0.4",
1942+
) as plate:
1943+
pos = plate.create_position("A", "1", "0")
1944+
pos.create_image("0", data)
1945+
# Flat metadata, no "ome" wrapper
1946+
assert "plate" in plate.zattrs
1947+
assert "ome" not in plate.zattrs
1948+
with open_ome_zarr(store_path, layout="hcs", mode="r") as plate:
1949+
assert plate.version == "0.4"
1950+
assert_array_equal(plate["A/1/0"]["0"][:], data)
1951+
1952+
1953+
def test_sharding_raises_on_v04(tmp_path):
1954+
"""Sharding must raise ValueError for v0.4."""
1955+
store_path = tmp_path / "v04_shard.zarr"
1956+
with open_ome_zarr(
1957+
store_path,
1958+
layout="fov",
1959+
mode="w-",
1960+
channel_names=["A"],
1961+
version="0.4",
1962+
) as ds:
1963+
with pytest.raises(ValueError, match="Sharding is not supported"):
1964+
ds.create_image(
1965+
"0",
1966+
np.zeros((1, 1, 1, 64, 64)),
1967+
shards_ratio=(1, 1, 1, 2, 2),
1968+
)

0 commit comments

Comments
 (0)