diff --git a/src/iohub/core/__init__.py b/src/iohub/core/__init__.py index 24e1609f..5ba75b0a 100644 --- a/src/iohub/core/__init__.py +++ b/src/iohub/core/__init__.py @@ -1,7 +1,7 @@ """Core zarr implementation abstraction for iohub.""" from iohub.core.arrays import NGFFArray -from iohub.core.compat import ngff_version_for_format, zarr_format_for_version +from iohub.core.compat import get_ome_attrs, ngff_version_for_format, zarr_format_for_version from iohub.core.config import ( CompressorConfig, ImplementationConfig, diff --git a/src/iohub/core/compat.py b/src/iohub/core/compat.py index c2dfcafb..88e965fa 100644 --- a/src/iohub/core/compat.py +++ b/src/iohub/core/compat.py @@ -16,6 +16,16 @@ def zarr_format_for_version(version: NGFFVersion) -> ZarrFormat: raise ValueError(f"Unknown NGFF version: {version!r}. Supported: {list(NGFF_TO_ZARR_FORMAT)}") from err +def get_ome_attrs(attrs) -> dict: + """Extract OME metadata dict from zarr attrs, regardless of NGFF version. + + v0.5 stores wrap metadata under an ``"ome"`` key; v0.4 stores + place it flat in ``.zattrs``. This function returns the OME + metadata dict in both cases. + """ + return attrs.get("ome") or dict(attrs) + + def ngff_version_for_format(zarr_format: ZarrFormat) -> NGFFVersion: """Map zarr format integer to NGFF version string.""" try: diff --git a/src/iohub/core/config.py b/src/iohub/core/config.py index e96b879c..696d2f12 100644 --- a/src/iohub/core/config.py +++ b/src/iohub/core/config.py @@ -36,6 +36,7 @@ class ZarrConfig(BaseModel): class TensorStoreConfig(BaseModel): """Config for the TensorStore implementation.""" + compressor: CompressorConfig = Field(default_factory=CompressorConfig) data_copy_concurrency: int = Field(default=4, ge=1) context: dict | None = None file_io_concurrency: int | None = None diff --git a/src/iohub/core/implementations/tensorstore.py b/src/iohub/core/implementations/tensorstore.py index b55f0272..081e5b14 100644 --- a/src/iohub/core/implementations/tensorstore.py +++ b/src/iohub/core/implementations/tensorstore.py @@ -1,12 +1,11 @@ -"""TensorStore implementation (optional dependency).""" +"""TensorStore implementation -- zarr-python groups + TensorStore array I/O.""" from __future__ import annotations -import json -from pathlib import Path from typing import TYPE_CHECKING, Any import numpy as np +import zarr from iohub.core.config import TensorStoreConfig from iohub.core.protocol import ZarrImplementation @@ -25,169 +24,6 @@ import tensorstore as ts -class _TsAttrs(dict): - """Persistent attrs for a _TsGroup. - - Reads from ``zarr.json`` (v3) or ``.zattrs`` (v2). - Writes always go to ``zarr.json``. - """ - - def __init__(self, group: _TsGroup): - self._group = group - super().__init__(self._load()) - - def _load(self) -> dict: - zarr_json = self._group.path / "zarr.json" - if zarr_json.exists(): - return json.loads(zarr_json.read_text()).get("attributes", {}) - zattrs = self._group.path / ".zattrs" - if zattrs.exists(): - return json.loads(zattrs.read_text()) - return {} - - def _save(self) -> None: - zarr_json = self._group.path / "zarr.json" - if zarr_json.exists(): - meta = json.loads(zarr_json.read_text()) - else: - meta = {"zarr_format": 3, "node_type": "group", "attributes": {}} - meta["attributes"] = dict(self) - zarr_json.write_text(json.dumps(meta)) - - def __setitem__(self, key, value): - super().__setitem__(key, value) - self._save() - - def update(self, *args, **kwargs): - super().update(*args, **kwargs) - self._save() - - -def _detect_zarr_driver(path: Path) -> str: - """Detect zarr format for a store root. Called once at open_group time.""" - if (path / "zarr.json").exists(): - return "zarr3" - if (path / ".zattrs").exists() or (path / ".zgroup").exists(): - return "zarr2" - return "zarr3" # new stores will be created as v3 always - - -class _TsGroup: - """Lightweight group handle (tensorstore has no native group concept).""" - - def __init__( - self, - path: Path, - mode: str, - impl: TensorStoreImplementation, - zarr_driver: str = "zarr3", - root: Path | None = None, - ): - if mode == "w-" and path.exists(): - raise FileExistsError(f"Store already exists: {path}") - if mode in ("w", "w-", "a") and not path.exists(): - path.mkdir(parents=True, exist_ok=True) - (path / "zarr.json").write_text('{"zarr_format": 3, "node_type": "group", "attributes": {}}') - self.path = path - self.mode = mode - self._impl = impl - self.zarr_driver = zarr_driver - self._root = root if root is not None else path - - def create_group(self, name: str, overwrite: bool = False) -> _TsGroup: - sub = self.path / name - if sub.exists() and not overwrite: - return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver) - sub.mkdir(parents=True, exist_ok=True) - zarr_json = sub / "zarr.json" - if not zarr_json.exists() or overwrite: - zarr_json.write_text(json.dumps({"zarr_format": 3, "node_type": "group", "attributes": {}})) - return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver) - - def __contains__(self, name: str) -> bool: - return self.get(name) is not None - - def __delitem__(self, name: str) -> None: - import shutil - - sub = self.path / name - if not sub.exists(): - raise KeyError(name) - shutil.rmtree(sub) - - def __getitem__(self, name: str): - result = self.get(name) - if result is None: - raise KeyError(name) - return result - - def get(self, name: str, default=None): - sub = self.path / name - if not sub.is_dir(): - return default - if self.zarr_driver == "zarr3": - try: - meta = json.loads((sub / "zarr.json").read_text()) - if meta.get("node_type") == "array": - return self._impl.open_array(self, name) - if meta.get("node_type") == "group": - return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver, root=self._root) - except (OSError, ValueError): - pass - else: - if (sub / ".zarray").exists(): - return self._impl.open_array(self, name) - if (sub / ".zgroup").exists(): - return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver, root=self._root) - return default - - @property - def store(self) -> _TsGroup: - return self - - @property - def root(self) -> Path: - return self.path - - @property - def name(self) -> str: - try: - rel = self.path.relative_to(self._root) - return "/" + str(rel) if str(rel) != "." else "/" - except ValueError: - return str(self.path) - - @property - def basename(self) -> str: - return self.path.name - - @property - def attrs(self) -> _TsAttrs: - if not hasattr(self, "_attrs_cache") or self._attrs_cache is None: - self._attrs_cache = _TsAttrs(self) - return self._attrs_cache - - def tree(self, level: int | None = None) -> str: - lines = [self.basename] - self._tree_lines(self.path, "", level, 0, lines) - return "\n".join(lines) - - def _tree_lines(self, p: Path, prefix: str, max_level: int | None, depth: int, lines: list) -> None: - if max_level is not None and depth >= max_level: - return - try: - children = sorted( - d for d in (entry.name for entry in p.iterdir()) if (p / d).is_dir() and not d.startswith(".") - ) - except OSError: - return - for i, child in enumerate(children): - connector = "└── " if i == len(children) - 1 else "├── " - lines.append(f"{prefix}{connector}{child}") - extension = " " if i == len(children) - 1 else "│ " - self._tree_lines(p / child, prefix + extension, max_level, depth + 1, lines) - - def _fill_value_for_spec(data_type: str, fill_value: int | float) -> object: """Return a TensorStore-compatible fill value for the given dtype string.""" if data_type == "bool": @@ -228,11 +64,28 @@ def _spec_to_ts(spec: ArraySpec, path: str) -> dict: return {"driver": "zarr3", "kvstore": {"driver": "file", "path": path}, "metadata": metadata} -_TS_IMPL_BASE = ZarrImplementation[_TsGroup, ts.TensorStore] if _TS_AVAILABLE else object # type: ignore[assignment] +def _resolve_array_path(group: zarr.Group, name: str) -> str: + """Get filesystem path for an array within a zarr.Group.""" + store = group.store + if not hasattr(store, "root"): + raise TypeError(f"TensorStore requires a LocalStore (filesystem) backend, got {type(store).__name__!r}.") + root = store.root + gpath = group.path + if gpath: + return str(root / gpath / name) + return str(root / name) + + +_TS_IMPL_BASE = ZarrImplementation[zarr.Group, ts.TensorStore] if _TS_AVAILABLE else object # type: ignore[assignment] class TensorStoreImplementation(_TS_IMPL_BASE): - """TensorStore-backed I/O implementation.""" + """Hybrid implementation: zarr-python groups + TensorStore array I/O. + + Group operations (metadata, hierarchy) are delegated to zarr-python. + Array operations (create, read, write, downsample) use TensorStore + for high-performance I/O with configurable concurrency and caching. + """ def __init__(self, config: TensorStoreConfig | None = None): self.config = config or TensorStoreConfig() @@ -267,56 +120,34 @@ def _context(self) -> ts.Context: self._ctx = ts.Context(ctx_opts) return self._ctx - # -- Group operations -------------------------------------------------- - - def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> _TsGroup: - p = Path(path) - return _TsGroup(path=p, mode=mode, impl=self, zarr_driver=_detect_zarr_driver(p), root=p) - - def _iter_children(self, group: _TsGroup, node_type: str) -> list[str]: - """Return sorted child names matching node_type ('group' or 'array').""" - p = Path(group.path) - if not p.is_dir(): - return [] - keys: list[str] = [] - match group.zarr_driver: - case "zarr3": - for entry in p.iterdir(): - d = entry.name - if not entry.is_dir() or d.startswith("."): - continue - try: - meta = json.loads((entry / "zarr.json").read_text()) - if meta.get("node_type") == node_type: - keys.append(d) - except (OSError, ValueError): - pass - case "zarr2": - sentinel = {"group": ".zgroup", "array": ".zarray"}[node_type] - keys = [e.name for e in p.iterdir() if (p / e.name / sentinel).exists()] - return sorted(keys) - - def group_keys(self, group: _TsGroup) -> list[str]: - return self._iter_children(group, "group") - - def array_keys(self, group: _TsGroup) -> list[str]: - return self._iter_children(group, "array") - - def close(self, group: _TsGroup) -> None: - pass # TensorStore handles are not persistent connections - - def get_zarr_format(self, group: _TsGroup) -> int: - return 3 # TensorStore only supports zarr v3 + # -- Group operations (delegated to zarr-python) ----------------------- + + def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> zarr.Group: + return zarr.open_group(path, mode=mode, zarr_format=zarr_format) + + def group_keys(self, group: zarr.Group) -> list[str]: + return sorted(group.group_keys()) + + def array_keys(self, group: zarr.Group) -> list[str]: + return sorted(group.array_keys()) + + def close(self, group: zarr.Group) -> None: + group.store.close() + + def get_zarr_format(self, group: zarr.Group) -> int: + return group.metadata.zarr_format # -- Array lifecycle --------------------------------------------------- - def create_array(self, group: _TsGroup, name: str, spec: ArraySpec, *, overwrite: bool = False) -> ts.TensorStore: - ts_spec = _spec_to_ts(spec, str(Path(group.path) / name)) + def create_array(self, group: zarr.Group, name: str, spec: ArraySpec, *, overwrite: bool = False) -> ts.TensorStore: + path = _resolve_array_path(group, name) + self._array_cache.pop(path, None) + ts_spec = _spec_to_ts(spec, path) return _ts_open(ts_spec, create=True, delete_existing=overwrite, context=self._context()) def create_array_v2( self, - group: _TsGroup, + group: zarr.Group, name: str, *, shape: tuple[int, ...], @@ -325,14 +156,27 @@ def create_array_v2( fill_value: int = 0, overwrite: bool = False, ) -> ts.TensorStore: + shuffle_map = {"noshuffle": 0, "shuffle": 1, "bitshuffle": 2} + comp = self.config.compressor + path = _resolve_array_path(group, name) + self._array_cache.pop(path, None) + # TensorStore zarr2 driver requires bool fill_value for bool dtype + resolved_dtype = np.dtype(dtype) + if resolved_dtype.kind == "b": + fill_value = bool(fill_value) spec = { "driver": "zarr2", - "kvstore": {"driver": "file", "path": str(Path(group.path) / name)}, + "kvstore": {"driver": "file", "path": path}, "metadata": { "shape": list(shape), "chunks": list(chunks), - "dtype": np.dtype(dtype).str, # zarr2 uses NumPy dtype strings e.g. " ts.TensorStore: - key = str(Path(group.path) / name) + def open_array(self, group: zarr.Group, name: str) -> ts.TensorStore: + key = _resolve_array_path(group, name) if key not in self._array_cache: + driver = "zarr3" if group.metadata.zarr_format == 3 else "zarr2" + writable = not getattr(group.store, "read_only", False) spec = { - "driver": group.zarr_driver, + "driver": driver, "kvstore": {"driver": "file", "path": key}, } self._array_cache[key] = _ts_open( spec, open=True, read=True, - write=(group.mode != "r"), + write=writable, context=self._context(), ) return self._array_cache[key] diff --git a/src/iohub/core/protocol.py b/src/iohub/core/protocol.py index 99c1ae7b..7a3330e3 100644 --- a/src/iohub/core/protocol.py +++ b/src/iohub/core/protocol.py @@ -1,12 +1,12 @@ """ZarrImplementation Protocol -- the contract for zarr I/O backends. Type parameter conventions: - G -- the native group handle type (e.g. ``zarr.Group``, ``_TsGroup``) + G -- the native group handle type (e.g. ``zarr.Group``) A -- the native array handle type (e.g. ``zarr.Array``, ``ts.TensorStore``) Concrete bindings per implementation: ZarrPythonImplementation -> G=zarr.Group, A=zarr.Array - TensorStoreImplementation -> G=_TsGroup, A=ts.TensorStore + TensorStoreImplementation -> G=zarr.Group, A=ts.TensorStore """ from typing import Any, Protocol, runtime_checkable diff --git a/src/iohub/ngff/models.py b/src/iohub/ngff/models.py index c3553083..b17476c5 100644 --- a/src/iohub/ngff/models.py +++ b/src/iohub/ngff/models.py @@ -292,8 +292,7 @@ class LabelImageMeta(MetaBase): multiscales: list[MultiScaleMeta] # SHOULD: image-label with colors, properties, source image_label: PositionLabelMeta = Field(alias="image-label") - # only for OME-NGFF v0.5 - version: Literal["0.5"] | None = None + version: Literal["0.4", "0.5"] = "0.5" model_config = ConfigDict(extra="allow") @@ -308,8 +307,7 @@ class ImagesMeta(MetaBase): omero: OMEROMeta | None = None # labels group support labels: LabelsMeta | None = None - # only for OME-NGFF v0.5 - version: Literal["0.5"] | None = None + version: Literal["0.4", "0.5"] = "0.5" model_config = ConfigDict(extra="allow") diff --git a/src/iohub/ngff/nodes.py b/src/iohub/ngff/nodes.py index c01afe70..2459ae11 100644 --- a/src/iohub/ngff/nodes.py +++ b/src/iohub/ngff/nodes.py @@ -26,6 +26,7 @@ from pydantic import ValidationError from iohub.core import ArraySpec, NGFFArray, get_implementation +from iohub.core.compat import get_ome_attrs from iohub.core.config import ImplementationConfig from iohub.core.errors import StoreOpenError from iohub.core.protocol import ZarrImplementation @@ -92,7 +93,7 @@ def _open_store( try: zarr_format = None if mode in ("w", "w-") or (is_fs and mode == "a" and not store_path.exists()): - zarr_format = 3 + zarr_format = 2 if version == "0.4" else 3 root = impl.open_group(store_path, mode=mode, zarr_format=zarr_format) except (FileNotFoundError, FileExistsError, PermissionError): raise @@ -184,7 +185,7 @@ def zattrs(self): @property def maybe_wrapped_ome_attrs(self): """Container of OME metadata attributes.""" - return self.zattrs.get("ome") or self.zattrs + return get_ome_attrs(self.zattrs) @property def version(self) -> Literal["0.4", "0.5"]: @@ -225,16 +226,20 @@ def __len__(self): def __getitem__(self, key): key = normalize_path(str(key)) - znode = self.zgroup.get(key) - if not znode: - raise KeyError(key) levels = len(key.split("/")) - 1 item_type = self._MEMBER_TYPE for _ in range(levels): item_type = item_type._MEMBER_TYPE if issubclass(item_type, NGFFArray): - return item_type.from_handle(znode, self._impl) + try: + handle = self._impl.open_array(self._group, key) + except (FileNotFoundError, KeyError) as err: + raise KeyError(key) from err + return item_type.from_handle(handle, self._impl) else: + znode = self.zgroup.get(key) + if not znode: + raise KeyError(key) return item_type(group=znode, parse_meta=True, **self._child_attrs) def __setitem__(self, key, value): @@ -382,6 +387,10 @@ def _create_zarr_array( shards = tuple(c * s for c, s in zip(chunks, shards_ratio, strict=False)) else: shards = None + if shards is not None and self._zarr_format == 2: + raise ValueError( + "Sharding is not supported in Zarr v2 (OME-Zarr v0.4). Remove shards_ratio or use version='0.5'." + ) if self._zarr_format == 3: spec = ArraySpec.create( shape=shape, @@ -663,7 +672,10 @@ def __init__( def _parse_meta(self): """Parse multiscales and image-label metadata.""" try: - self.metadata = LabelImageMeta.model_validate(self.maybe_wrapped_ome_attrs) + attrs = dict(self.maybe_wrapped_ome_attrs) + if "version" not in attrs: + attrs["version"] = self.version + self.metadata = LabelImageMeta.model_validate(attrs) except ValidationError as e: _logger.warning(str(e)) self._warn_invalid_meta() @@ -683,10 +695,11 @@ def data(self) -> LabelsArray: def __getitem__(self, key: int | str) -> LabelsArray: key = normalize_path(str(key)) - znode = self.zgroup.get(key) - if not znode: - raise KeyError(key) - return LabelsArray.from_handle(znode, self._impl) + try: + handle = self._impl.open_array(self._group, key) + except (FileNotFoundError, KeyError) as err: + raise KeyError(key) from err + return LabelsArray.from_handle(handle, self._impl) def create_label( self, @@ -832,7 +845,7 @@ def _create_label_meta( ) ], image_label=image_label_meta, - version="0.5" if self.version == "0.5" else None, + version=self.version, ) elif dataset_meta.path not in self.metadata.multiscales[0].get_dataset_paths(): self.metadata.multiscales[0].datasets.append(dataset_meta) @@ -949,7 +962,10 @@ def _set_meta(self): def _parse_meta(self): try: - self.metadata = ImagesMeta.model_validate(self.maybe_wrapped_ome_attrs) + attrs = dict(self.maybe_wrapped_ome_attrs) + if "version" not in attrs: + attrs["version"] = self.version + self.metadata = ImagesMeta.model_validate(attrs) self._set_meta() except ValidationError as e: _logger.warning(str(e)) @@ -1179,7 +1195,7 @@ def _create_image_meta( ) ], omero=self._omero_meta(id=0, name=self._group.basename), - version="0.5" if self.version == "0.5" else None, + version=self.version, ) elif dataset_meta.path not in self.metadata.multiscales[0].get_dataset_paths(): self.metadata.multiscales[0].datasets.append(dataset_meta) @@ -1964,13 +1980,10 @@ def to_xarray(self) -> xr.DataArray: import dask.array as da - # Always use zarr-python for the dask array — other backends - # (e.g. TensorStore) may not integrate reliably with dask/xarray. + # Always use zarr-python for the dask array — TensorStore + # may not integrate reliably with dask/xarray. arr_name = self.metadata.multiscales[0].datasets[0].path - if isinstance(self._group, zarr.Group): - data = da.from_zarr(self._group[arr_name]) - else: - data = da.from_zarr(zarr.open_array(str(self._group.path / arr_name), mode="r")) + data = da.from_zarr(self._group[arr_name]) # Build axis unit lookup from OME metadata axis_units = {} for axis in self.axes: @@ -2480,7 +2493,7 @@ def _first_pos(self): try: well_path = self.metadata.wells[0].path well_grp = self.zgroup[well_path] - attrs = well_grp.attrs.get("ome") or dict(well_grp.attrs) + attrs = get_ome_attrs(well_grp.attrs) pos_name = attrs["well"]["images"][0]["path"] return Position( group=well_grp[pos_name], @@ -3066,8 +3079,6 @@ def open_ome_zarr( if _is_fslike(store_path): store_path = Path(store_path) _is_new_store = mode in ("w", "w-") or (mode == "a" and _is_fslike(store_path) and not store_path.exists()) - if version == "0.4" and _is_new_store: - raise ValueError("Creating new OME-Zarr v0.4 stores is not supported. Use version='0.5' instead.") parse_meta = _check_file_mode(store_path, mode, disable_path_checking=disable_path_checking) root, impl = _open_store( store_path, diff --git a/src/iohub/ngff/utils.py b/src/iohub/ngff/utils.py index c905c7c1..7001b7c9 100644 --- a/src/iohub/ngff/utils.py +++ b/src/iohub/ngff/utils.py @@ -418,10 +418,12 @@ def process_single_position( partial_apply_transform_to_czyx_and_save(*args) else: with ThreadPoolExecutor(max_workers=num_workers) as executor: - list(executor.map( - lambda args: partial_apply_transform_to_czyx_and_save(*args), - flat_iterable, - )) + list( + executor.map( + lambda args: partial_apply_transform_to_czyx_and_save(*args), + flat_iterable, + ) + ) click.echo("Shut down thread pool") diff --git a/tests/ngff/test_ngff.py b/tests/ngff/test_ngff.py index 64f74b14..fddf6fdb 100644 --- a/tests/ngff/test_ngff.py +++ b/tests/ngff/test_ngff.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: from _typeshed import StrPath +from iohub.core.compat import get_ome_attrs from iohub.core.utils import pad_shape from iohub.ngff.models import TO_DICT_SETTINGS from iohub.ngff.nodes import ( @@ -45,7 +46,7 @@ y_dim_st = st.integers(1, 32) x_dim_st = st.integers(1, 32) channel_names_st = c_dim_st.flatmap(lambda c_dim: st.lists(short_text_st, min_size=c_dim, max_size=c_dim, unique=True)) -ngff_versions_st = st.just("0.5") +ngff_versions_st = st.sampled_from(["0.4", "0.5"]) short_alpha_numeric = st.text( alphabet=list(string.ascii_lowercase + string.ascii_uppercase + string.digits), min_size=1, @@ -238,28 +239,33 @@ def test_init_ome_zarr(channel_names, version): @pytest.mark.parametrize("mode", ["w", "w-"]) -def test_open_ome_zarr_v04_write_raises(tmp_path, mode): - """Creating new v0.4 stores must raise ValueError for all write modes.""" - with pytest.raises(ValueError, match=r"v0.4"): - open_ome_zarr( - tmp_path / "out.zarr", - layout="fov", - mode=mode, - channel_names=["DAPI"], - version="0.4", - ) - - -def test_open_ome_zarr_v04_append_new_path_raises(tmp_path): - """mode='a' on a nonexistent path is a new store and must also raise.""" - with pytest.raises(ValueError, match=r"v0.4"): - open_ome_zarr( - tmp_path / "nonexistent.zarr", - layout="fov", - mode="a", - channel_names=["DAPI"], - version="0.4", - ) +def test_open_ome_zarr_v04_write_succeeds(tmp_path, mode): + """Creating new v0.4 stores must succeed.""" + store_path = tmp_path / "out.zarr" + with open_ome_zarr( + store_path, + layout="fov", + mode=mode, + channel_names=["DAPI"], + version="0.4", + ) as ds: + assert ds.version == "0.4" + assert (store_path / ".zgroup").exists() + assert not (store_path / "zarr.json").exists() + + +def test_open_ome_zarr_v04_append_new_path_succeeds(tmp_path): + """mode='a' on a nonexistent path should create a v0.4 store.""" + store_path = tmp_path / "nonexistent.zarr" + with open_ome_zarr( + store_path, + layout="fov", + mode="a", + channel_names=["DAPI"], + version="0.4", + ) as ds: + assert ds.version == "0.4" + assert (store_path / ".zgroup").exists() @pytest.mark.parametrize("version", ["0.5"]) @@ -391,12 +397,14 @@ def test_write_ome_zarr(channels_and_random_5d, arr_name, version): channel_names, random_5d = channels_and_random_5d with _temp_ome_zarr(random_5d, channel_names, arr_name, version=version) as dataset: assert_allclose(dataset[arr_name][:], random_5d) - # round-trip test with the offical reader implementation - ext_reader = ome_zarr.reader.Reader(ome_zarr.io.parse_url(dataset.zgroup.store.root)) - node = next(iter(ext_reader())) - assert node.metadata["channel_names"] == channel_names - assert node.specs[0].datasets == [arr_name] - assert_allclose(node.data[0], random_5d) + if version == "0.5": + # round-trip test with the official reader implementation + # ome-zarr-py reader requires zarr-python Group with .store.root + ext_reader = ome_zarr.reader.Reader(ome_zarr.io.parse_url(dataset.zgroup.store.root)) + node = next(iter(ext_reader())) + assert node.metadata["channel_names"] == channel_names + assert node.specs[0].datasets == [arr_name] + assert_allclose(node.data[0], random_5d) @given( @@ -421,9 +429,10 @@ def test_create_zeros(ch_shape_dtype, arr_name, version): version=version, ) dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype) - assert {e.name for e in (Path(store_path) / arr_name).iterdir()} == { - "zarr.json", - } + if version == "0.5": + assert (Path(store_path) / arr_name / "zarr.json").exists() + else: + assert (Path(store_path) / arr_name / ".zarray").exists() if version == "0.5": assert dataset[arr_name].metadata.dimension_names == ( "T", @@ -817,10 +826,7 @@ def test_set_transform_fov(ch_shape_dtype, arr_name, version): assert dataset.metadata.multiscales[0].coordinate_transformations == transform # read data with plain zarr group = zarr.open(store_path) - if version == "0.4": - maybe_ome = group.attrs - elif version == "0.5": - maybe_ome = group.attrs["ome"] + maybe_ome = get_ome_attrs(group.attrs) assert maybe_ome["multiscales"][0]["coordinateTransformations"] == [ translate.model_dump(**TO_DICT_SETTINGS) for translate in transform ] @@ -1149,7 +1155,7 @@ def test_create_well(row_names: list[str], col_names: list[str]): for row_name in row_names: for col_name in col_names: dataset.create_well(row_name, col_name) - plate_meta = dataset.zattrs.get("ome", dataset.zattrs)["plate"] + plate_meta = get_ome_attrs(dataset.zattrs)["plate"] assert [c["name"] for c in plate_meta["columns"]] == col_names assert [r["name"] for r in plate_meta["rows"]] == row_names @@ -1198,7 +1204,7 @@ def test_create_position(row, col, pos, version): version=version, ) _ = dataset.create_position(row_name=row, col_name=col, pos_name=pos) - ome = dataset.zgroup.attrs["ome"] + ome = get_ome_attrs(dataset.zgroup.attrs) assert [c["name"] for c in ome["plate"]["columns"]] == [col] assert [r["name"] for r in ome["plate"]["rows"]] == [row] assert (store_path / row / col / pos).is_dir() @@ -1249,7 +1255,7 @@ def test_create_positions(tmp_path, version): # Collect positions and compare those - get_metadata = lambda x: dict(x.zgroup.attrs["ome"]) + get_metadata = lambda x: get_ome_attrs(x.zgroup.attrs) single_plate_metadata = get_metadata(single) batched_plate_metadata = get_metadata(batched) @@ -1308,7 +1314,7 @@ def test_create_positions_with_tuple_variations(tmp_path, version): batched.create_positions(positions) # Verify metadata matches - get_metadata = lambda x: dict(x.zgroup.attrs["ome"]) + get_metadata = lambda x: get_ome_attrs(x.zgroup.attrs) single_meta = get_metadata(single) batched_meta = get_metadata(batched) @@ -1892,3 +1898,69 @@ def test_initialize_pyramid_invalid_dims(implementation, tmp_path): pos.create_zeros("0", shape=(1, 1, 2, 8, 8), dtype=np.float32) with pytest.raises(ValueError, match="not in dataset axes"): pos.initialize_pyramid(levels=2, dims={"w"}) + + +# ---------- v0.4 dedicated tests ---------- + + +def test_write_ome_zarr_v04_fov_roundtrip(tmp_path): + """Full round-trip: create v0.4 FOV store, write image, read back.""" + store_path = tmp_path / "v04.ome.zarr" + data = np.random.default_rng(42).random((1, 2, 3, 64, 64)).astype(np.float32) + with open_ome_zarr( + store_path, + layout="fov", + mode="w-", + channel_names=["A", "B"], + version="0.4", + ) as ds: + ds.create_image("0", data) + assert ds.version == "0.4" + # Verify v2 file structure + assert (store_path / ".zgroup").exists() + assert (store_path / ".zattrs").exists() + assert not (store_path / "zarr.json").exists() + # Re-open read-only + with open_ome_zarr(store_path, layout="fov", mode="r") as ds: + assert ds.version == "0.4" + assert_array_equal(ds["0"][:], data) + assert ds.channel_names == ["A", "B"] + + +def test_write_ome_zarr_v04_hcs_roundtrip(tmp_path): + """HCS plate creation with v0.4.""" + store_path = tmp_path / "v04_hcs.ome.zarr" + data = np.zeros((1, 2, 3, 32, 32), dtype=np.uint16) + with open_ome_zarr( + store_path, + layout="hcs", + mode="w-", + channel_names=["A", "B"], + version="0.4", + ) as plate: + pos = plate.create_position("A", "1", "0") + pos.create_image("0", data) + # Flat metadata, no "ome" wrapper + assert "plate" in plate.zattrs + assert "ome" not in plate.zattrs + with open_ome_zarr(store_path, layout="hcs", mode="r") as plate: + assert plate.version == "0.4" + assert_array_equal(plate["A/1/0"]["0"][:], data) + + +def test_sharding_raises_on_v04(tmp_path): + """Sharding must raise ValueError for v0.4.""" + store_path = tmp_path / "v04_shard.zarr" + with open_ome_zarr( + store_path, + layout="fov", + mode="w-", + channel_names=["A"], + version="0.4", + ) as ds: + with pytest.raises(ValueError, match="Sharding is not supported"): + ds.create_image( + "0", + np.zeros((1, 1, 1, 64, 64)), + shards_ratio=(1, 1, 1, 2, 2), + )