Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/iohub/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ZarrConfig(BaseModel):
class TensorStoreConfig(BaseModel):
"""Config for the TensorStore implementation."""

compressor: CompressorConfig = Field(default_factory=CompressorConfig)
data_copy_concurrency: int = Field(default=4, ge=1)
context: dict | None = None
file_io_concurrency: int | None = None
Expand Down
282 changes: 64 additions & 218 deletions src/iohub/core/implementations/tensorstore.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""TensorStore implementation (optional dependency)."""
"""TensorStore implementation -- zarr-python groups + TensorStore array I/O."""

from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import zarr

from iohub.core.config import TensorStoreConfig
from iohub.core.protocol import ZarrImplementation
Expand All @@ -25,169 +24,6 @@
import tensorstore as ts


class _TsAttrs(dict):
"""Persistent attrs for a _TsGroup.

Reads from ``zarr.json`` (v3) or ``.zattrs`` (v2).
Writes always go to ``zarr.json``.
"""

def __init__(self, group: _TsGroup):
self._group = group
super().__init__(self._load())

def _load(self) -> dict:
zarr_json = self._group.path / "zarr.json"
if zarr_json.exists():
return json.loads(zarr_json.read_text()).get("attributes", {})
zattrs = self._group.path / ".zattrs"
if zattrs.exists():
return json.loads(zattrs.read_text())
return {}

def _save(self) -> None:
zarr_json = self._group.path / "zarr.json"
if zarr_json.exists():
meta = json.loads(zarr_json.read_text())
else:
meta = {"zarr_format": 3, "node_type": "group", "attributes": {}}
meta["attributes"] = dict(self)
zarr_json.write_text(json.dumps(meta))

def __setitem__(self, key, value):
super().__setitem__(key, value)
self._save()

def update(self, *args, **kwargs):
super().update(*args, **kwargs)
self._save()


def _detect_zarr_driver(path: Path) -> str:
"""Detect zarr format for a store root. Called once at open_group time."""
if (path / "zarr.json").exists():
return "zarr3"
if (path / ".zattrs").exists() or (path / ".zgroup").exists():
return "zarr2"
return "zarr3" # new stores will be created as v3 always


class _TsGroup:
"""Lightweight group handle (tensorstore has no native group concept)."""

def __init__(
self,
path: Path,
mode: str,
impl: TensorStoreImplementation,
zarr_driver: str = "zarr3",
root: Path | None = None,
):
if mode == "w-" and path.exists():
raise FileExistsError(f"Store already exists: {path}")
if mode in ("w", "w-", "a") and not path.exists():
path.mkdir(parents=True, exist_ok=True)
(path / "zarr.json").write_text('{"zarr_format": 3, "node_type": "group", "attributes": {}}')
self.path = path
self.mode = mode
self._impl = impl
self.zarr_driver = zarr_driver
self._root = root if root is not None else path

def create_group(self, name: str, overwrite: bool = False) -> _TsGroup:
sub = self.path / name
if sub.exists() and not overwrite:
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver)
sub.mkdir(parents=True, exist_ok=True)
zarr_json = sub / "zarr.json"
if not zarr_json.exists() or overwrite:
zarr_json.write_text(json.dumps({"zarr_format": 3, "node_type": "group", "attributes": {}}))
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver)

def __contains__(self, name: str) -> bool:
return self.get(name) is not None

def __delitem__(self, name: str) -> None:
import shutil

sub = self.path / name
if not sub.exists():
raise KeyError(name)
shutil.rmtree(sub)

def __getitem__(self, name: str):
result = self.get(name)
if result is None:
raise KeyError(name)
return result

def get(self, name: str, default=None):
sub = self.path / name
if not sub.is_dir():
return default
if self.zarr_driver == "zarr3":
try:
meta = json.loads((sub / "zarr.json").read_text())
if meta.get("node_type") == "array":
return self._impl.open_array(self, name)
if meta.get("node_type") == "group":
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver, root=self._root)
except (OSError, ValueError):
pass
else:
if (sub / ".zarray").exists():
return self._impl.open_array(self, name)
if (sub / ".zgroup").exists():
return _TsGroup(path=sub, mode="a", impl=self._impl, zarr_driver=self.zarr_driver, root=self._root)
return default

@property
def store(self) -> _TsGroup:
return self

@property
def root(self) -> Path:
return self.path

@property
def name(self) -> str:
try:
rel = self.path.relative_to(self._root)
return "/" + str(rel) if str(rel) != "." else "/"
except ValueError:
return str(self.path)

@property
def basename(self) -> str:
return self.path.name

@property
def attrs(self) -> _TsAttrs:
if not hasattr(self, "_attrs_cache") or self._attrs_cache is None:
self._attrs_cache = _TsAttrs(self)
return self._attrs_cache

def tree(self, level: int | None = None) -> str:
lines = [self.basename]
self._tree_lines(self.path, "", level, 0, lines)
return "\n".join(lines)

def _tree_lines(self, p: Path, prefix: str, max_level: int | None, depth: int, lines: list) -> None:
if max_level is not None and depth >= max_level:
return
try:
children = sorted(
d for d in (entry.name for entry in p.iterdir()) if (p / d).is_dir() and not d.startswith(".")
)
except OSError:
return
for i, child in enumerate(children):
connector = "└── " if i == len(children) - 1 else "├── "
lines.append(f"{prefix}{connector}{child}")
extension = " " if i == len(children) - 1 else "│ "
self._tree_lines(p / child, prefix + extension, max_level, depth + 1, lines)


def _fill_value_for_spec(data_type: str, fill_value: int | float) -> object:
"""Return a TensorStore-compatible fill value for the given dtype string."""
if data_type == "bool":
Expand Down Expand Up @@ -228,11 +64,28 @@ def _spec_to_ts(spec: ArraySpec, path: str) -> dict:
return {"driver": "zarr3", "kvstore": {"driver": "file", "path": path}, "metadata": metadata}


_TS_IMPL_BASE = ZarrImplementation[_TsGroup, ts.TensorStore] if _TS_AVAILABLE else object # type: ignore[assignment]
def _resolve_array_path(group: zarr.Group, name: str) -> str:
"""Get filesystem path for an array within a zarr.Group."""
store = group.store
if not hasattr(store, "root"):
raise TypeError(f"TensorStore requires a LocalStore (filesystem) backend, got {type(store).__name__!r}.")
root = store.root
gpath = group.path
if gpath:
return str(root / gpath / name)
return str(root / name)


_TS_IMPL_BASE = ZarrImplementation[zarr.Group, ts.TensorStore] if _TS_AVAILABLE else object # type: ignore[assignment]


class TensorStoreImplementation(_TS_IMPL_BASE):
"""TensorStore-backed I/O implementation."""
"""Hybrid implementation: zarr-python groups + TensorStore array I/O.

Group operations (metadata, hierarchy) are delegated to zarr-python.
Array operations (create, read, write, downsample) use TensorStore
for high-performance I/O with configurable concurrency and caching.
"""

def __init__(self, config: TensorStoreConfig | None = None):
self.config = config or TensorStoreConfig()
Expand Down Expand Up @@ -267,56 +120,34 @@ def _context(self) -> ts.Context:
self._ctx = ts.Context(ctx_opts)
return self._ctx

# -- Group operations --------------------------------------------------

def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> _TsGroup:
    """Open the store rooted at ``path`` and return its root group handle.

    The zarr driver is detected once from the files present at ``path``;
    ``zarr_format`` is accepted but not consulted.
    """
    store_root = Path(path)
    driver = _detect_zarr_driver(store_root)
    return _TsGroup(path=store_root, mode=mode, impl=self, zarr_driver=driver, root=store_root)

def _iter_children(self, group: _TsGroup, node_type: str) -> list[str]:
"""Return sorted child names matching node_type ('group' or 'array')."""
p = Path(group.path)
if not p.is_dir():
return []
keys: list[str] = []
match group.zarr_driver:
case "zarr3":
for entry in p.iterdir():
d = entry.name
if not entry.is_dir() or d.startswith("."):
continue
try:
meta = json.loads((entry / "zarr.json").read_text())
if meta.get("node_type") == node_type:
keys.append(d)
except (OSError, ValueError):
pass
case "zarr2":
sentinel = {"group": ".zgroup", "array": ".zarray"}[node_type]
keys = [e.name for e in p.iterdir() if (p / e.name / sentinel).exists()]
return sorted(keys)

def group_keys(self, group: _TsGroup) -> list[str]:
    """Return the sorted names of the child groups of ``group``."""
    child_groups = self._iter_children(group, "group")
    return child_groups

def array_keys(self, group: _TsGroup) -> list[str]:
    """Return the sorted names of the child arrays of ``group``."""
    child_arrays = self._iter_children(group, "array")
    return child_arrays

def close(self, group: _TsGroup) -> None:
    """Do nothing: TensorStore handles are not persistent connections."""
    return None

def get_zarr_format(self, group: _TsGroup) -> int:
    """Return the zarr format version (2 or 3) of ``group``.

    The version follows the driver detected when the store was opened, so
    existing v2 stores report 2 rather than a hard-coded 3.
    """
    return 2 if group.zarr_driver == "zarr2" else 3
# -- Group operations (delegated to zarr-python) -----------------------

def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> zarr.Group:
    """Open the group at ``path`` through zarr-python and return its handle."""
    handle = zarr.open_group(path, mode=mode, zarr_format=zarr_format)
    return handle

def group_keys(self, group: zarr.Group) -> list[str]:
    """Return the names yielded by ``group.group_keys()`` in sorted order."""
    names = list(group.group_keys())
    names.sort()
    return names

def array_keys(self, group: zarr.Group) -> list[str]:
    """Return the names yielded by ``group.array_keys()`` in sorted order."""
    names = list(group.array_keys())
    names.sort()
    return names

def close(self, group: zarr.Group) -> None:
    """Close the store backing ``group``."""
    backing_store = group.store
    backing_store.close()

def get_zarr_format(self, group: zarr.Group) -> int:
    """Return the ``zarr_format`` recorded in the metadata of ``group``."""
    metadata = group.metadata
    return metadata.zarr_format

# -- Array lifecycle ---------------------------------------------------

def create_array(self, group: _TsGroup, name: str, spec: ArraySpec, *, overwrite: bool = False) -> ts.TensorStore:
ts_spec = _spec_to_ts(spec, str(Path(group.path) / name))
def create_array(self, group: zarr.Group, name: str, spec: ArraySpec, *, overwrite: bool = False) -> ts.TensorStore:
path = _resolve_array_path(group, name)
self._array_cache.pop(path, None)
ts_spec = _spec_to_ts(spec, path)
return _ts_open(ts_spec, create=True, delete_existing=overwrite, context=self._context())

def create_array_v2(
self,
group: _TsGroup,
group: zarr.Group,
name: str,
*,
shape: tuple[int, ...],
Expand All @@ -325,33 +156,48 @@ def create_array_v2(
fill_value: int = 0,
overwrite: bool = False,
) -> ts.TensorStore:
shuffle_map = {"noshuffle": 0, "shuffle": 1, "bitshuffle": 2}
comp = self.config.compressor
path = _resolve_array_path(group, name)
self._array_cache.pop(path, None)
# TensorStore zarr2 driver requires bool fill_value for bool dtype
resolved_dtype = np.dtype(dtype)
if resolved_dtype.kind == "b":
fill_value = bool(fill_value)
spec = {
"driver": "zarr2",
"kvstore": {"driver": "file", "path": str(Path(group.path) / name)},
"kvstore": {"driver": "file", "path": path},
"metadata": {
"shape": list(shape),
"chunks": list(chunks),
"dtype": np.dtype(dtype).str, # zarr2 uses NumPy dtype strings e.g. "<u2"
"compressor": {"id": "blosc", "cname": "lz4", "clevel": 5, "shuffle": 1},
"dtype": resolved_dtype.str,
"compressor": {
"id": "blosc",
"cname": comp.cname,
"clevel": comp.clevel,
"shuffle": shuffle_map.get(comp.shuffle, 2),
},
"fill_value": fill_value,
"order": "C",
"filters": None,
},
}
return _ts_open(spec, create=True, delete_existing=overwrite, context=self._context())

def open_array(self, group: _TsGroup, name: str) -> ts.TensorStore:
key = str(Path(group.path) / name)
def open_array(self, group: zarr.Group, name: str) -> ts.TensorStore:
key = _resolve_array_path(group, name)
if key not in self._array_cache:
driver = "zarr3" if group.metadata.zarr_format == 3 else "zarr2"
writable = not getattr(group.store, "read_only", False)
spec = {
"driver": group.zarr_driver,
"driver": driver,
"kvstore": {"driver": "file", "path": key},
}
self._array_cache[key] = _ts_open(
spec,
open=True,
read=True,
write=(group.mode != "r"),
write=writable,
context=self._context(),
)
return self._array_cache[key]
Expand Down
4 changes: 2 additions & 2 deletions src/iohub/core/protocol.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""ZarrImplementation Protocol -- the contract for zarr I/O backends.

Type parameter conventions:
G -- the native group handle type (e.g. ``zarr.Group``, ``_TsGroup``)
G -- the native group handle type (e.g. ``zarr.Group``)
A -- the native array handle type (e.g. ``zarr.Array``, ``ts.TensorStore``)

Concrete bindings per implementation:
ZarrPythonImplementation -> G=zarr.Group, A=zarr.Array
TensorStoreImplementation -> G=_TsGroup, A=ts.TensorStore
TensorStoreImplementation -> G=zarr.Group, A=ts.TensorStore
"""

from typing import Any, Protocol, runtime_checkable
Expand Down
Loading
Loading