Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions src/iohub/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
pass


class CompressorConfig(BaseModel):
Expand Down Expand Up @@ -34,7 +37,34 @@ class ZarrConfig(BaseModel):


class TensorStoreConfig(BaseModel):
"""Config for the TensorStore implementation."""
"""Config for the TensorStore implementation.

Parameters
----------
file_io_concurrency : int or None
Concurrency limit for TensorStore's ``file_io_concurrency``
resource. Raise above the default (32) on high-latency networked
filesystems (e.g. NFS) where the default under-saturates the link.
cache_pool_bytes : int or None
Aggregate byte budget for TensorStore's chunk cache pool. ``None``
disables caching.
recheck_cached_data : bool, "open" or None
Controls whether cached chunk data is re-validated on each read.
``None`` (default) uses the TensorStore driver default, which
revalidates cached chunk data on every read — one stat/GETATTR per
chunk. ``"open"`` checks freshness only when the array is opened
and trusts the cache thereafter — recommended for long-running
read-heavy workloads on NFS/VAST where the underlying zarr files
do not change. ``False`` disables freshness checks entirely.
shared_context : tensorstore.Context or None
If set, reuse this ``ts.Context`` instead of building a new one.
Lets multiple ``open_ome_zarr`` calls share one cache pool and
thread pool. When set, the other context knobs on this config are
ignored; only ``recheck_cached_data`` still applies.
"""

# Allow non-pydantic types (ts.Context) in the model.
model_config = ConfigDict(arbitrary_types_allowed=True)

compressor: CompressorConfig = Field(default_factory=CompressorConfig)
data_copy_concurrency: int = Field(default=4, ge=1)
Expand All @@ -43,7 +73,9 @@ class TensorStoreConfig(BaseModel):
file_io_sync: bool = True
file_io_locking: Literal["auto", "disabled"] = "auto"
cache_pool_bytes: int | None = None
recheck_cached_data: bool | Literal["open"] | None = None
extra_context: dict | None = None
shared_context: Any = None # ts.Context; Any avoids importing tensorstore at import-time


ImplementationConfig = ZarrConfig | TensorStoreConfig
19 changes: 12 additions & 7 deletions src/iohub/core/implementations/tensorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def __init__(self, config: TensorStoreConfig | None = None):
def _context(self) -> ts.Context:
import tensorstore as ts

if self.config.shared_context is not None:
return self.config.shared_context

if not hasattr(self, "_ctx") or self._ctx is None:
ctx_opts = dict(self.config.context or {})
if self.config.data_copy_concurrency:
Expand Down Expand Up @@ -193,13 +196,15 @@ def open_array(self, group: zarr.Group, name: str) -> ts.TensorStore:
"driver": driver,
"kvstore": {"driver": "file", "path": key},
}
self._array_cache[key] = _ts_open(
spec,
open=True,
read=True,
write=writable,
context=self._context(),
)
open_kwargs: dict[str, Any] = {
"open": True,
"read": True,
"write": writable,
"context": self._context(),
}
if self.config.recheck_cached_data is not None:
open_kwargs["recheck_cached_data"] = self.config.recheck_cached_data
self._array_cache[key] = _ts_open(spec, **open_kwargs)
return self._array_cache[key]

# -- Array I/O ---------------------------------------------------------
Expand Down
65 changes: 65 additions & 0 deletions tests/core/test_shared_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Tests for TensorStoreConfig.shared_context passthrough."""

from __future__ import annotations

import pytest


def test_shared_context_defaults_to_none() -> None:
    """A config built with no arguments leaves ``shared_context`` as ``None``."""
    from iohub.core.config import TensorStoreConfig

    # Backwards compatibility: callers that never pass shared_context see None.
    assert TensorStoreConfig().shared_context is None


def test_shared_context_accepts_ts_context() -> None:
    """A ``tensorstore.Context`` given as ``shared_context`` is stored unchanged."""
    ts = pytest.importorskip("tensorstore")

    from iohub.core.config import TensorStoreConfig

    shared = ts.Context({"cache_pool": {"total_bytes_limit": 1_000_000}})
    config = TensorStoreConfig(shared_context=shared)
    # Identity check: the config must expose the very object it was given.
    assert config.shared_context is shared


def test_shared_context_is_returned_by_implementation_context() -> None:
    """``TensorStoreImplementation._context()`` hands back the configured shared Context."""
    ts = pytest.importorskip("tensorstore")

    from iohub.core.config import TensorStoreConfig
    from iohub.core.implementations.tensorstore import TensorStoreImplementation

    shared_ctx = ts.Context({"cache_pool": {"total_bytes_limit": 2_000_000}})
    implementation = TensorStoreImplementation(
        config=TensorStoreConfig(shared_context=shared_ctx),
    )
    assert implementation._context() is shared_ctx


def test_two_implementations_share_one_context() -> None:
    """Implementations configured with one ``shared_context`` return the same Context."""
    ts = pytest.importorskip("tensorstore")

    from iohub.core.config import TensorStoreConfig
    from iohub.core.implementations.tensorstore import TensorStoreImplementation

    common_ctx = ts.Context({"cache_pool": {"total_bytes_limit": 500_000}})
    # Two separate config objects, both pointing at the one Context.
    first_cfg = TensorStoreConfig(shared_context=common_ctx)
    second_cfg = TensorStoreConfig(shared_context=common_ctx)
    first_impl = TensorStoreImplementation(config=first_cfg)
    second_impl = TensorStoreImplementation(config=second_cfg)
    assert first_impl._context() is second_impl._context()


def test_no_shared_context_falls_back_to_per_instance() -> None:
    """With ``shared_context`` unset, two implementations return distinct Contexts."""
    pytest.importorskip("tensorstore")

    from iohub.core.config import TensorStoreConfig
    from iohub.core.implementations.tensorstore import TensorStoreImplementation

    # Both implementations are built from the same default config object.
    default_cfg = TensorStoreConfig()
    first_impl = TensorStoreImplementation(config=default_cfg)
    second_impl = TensorStoreImplementation(config=default_cfg)
    assert first_impl._context() is not second_impl._context()