diff --git a/docs/examples/run_tile_hcs_well.py b/docs/examples/run_tile_hcs_well.py new file mode 100644 index 00000000..6a14172f --- /dev/null +++ b/docs/examples/run_tile_hcs_well.py @@ -0,0 +1,149 @@ +""" +Tile + Blend an HCS Well (Multi-FOV) +===================================== + +Create a synthetic HCS plate with 1 well and 4 FOVs arranged in a 2x2 +grid with physical overlap, then composite the FOVs into a single mosaic +and tile+blend with ``apply_func_tiled``. + +This demonstrates the full pipeline: +FOV compositing (``Well.to_xarray``) → tiling (``Tiler``) → blending (``apply_func_tiled``). +""" + +# %% +import os +import warnings +from tempfile import TemporaryDirectory + +import numpy as np + +from iohub.ngff import open_ome_zarr +from iohub.ngff.models import TransformationMeta +from iohub.tile import apply_func_tiled + +warnings.filterwarnings("ignore") + +# %% +# Create a synthetic HCS plate +# ------------------------------ +# 1 well ("A/1") with 4 FOVs in a 2x2 grid. +# Each FOV is 1t x 1c x 1z x 32y x 32x. +# FOVs overlap by 8 pixels (~25%) in both Y and X. +# +# Layout (pixel coordinates): +# +# .. 
code-block:: text +# +# FOV 0: y=[0,32), x=[0,32) FOV 1: y=[0,32), x=[24,56) +# FOV 2: y=[24,56), x=[0,32) FOV 3: y=[24,56), x=[24,56) +# +# Mosaic: 56 x 56 pixels (with 8px overlap strips between FOVs) + +tmp_dir = TemporaryDirectory() +plate_path = os.path.join(tmp_dir.name, "plate.zarr") + +rng = np.random.default_rng(123) + +# Pixel size (um/px) — use 1.0 for clean coordinate alignment +pixel_size = 1.0 + +# Grid step: 24 px → 8 px overlap per FOV pair (32 - 24 = 8) +grid_step = 24 + +# FOV grid positions: (row_idx, col_idx) → pixel origin (y, x) +fov_grid = { + "000": (0, 0), + "001": (0, 1), + "010": (1, 0), + "011": (1, 1), +} + +with open_ome_zarr(plate_path, layout="hcs", mode="w-", channel_names=["GFP"]) as plate: + for fov_name, (row_idx, col_idx) in fov_grid.items(): + pos = plate.create_position("A", "1", fov_name) + data = rng.random((1, 1, 1, 32, 32), dtype=np.float32) + pos.create_image("0", data, chunks=(1, 1, 1, 32, 32)) + + # Set physical scale and translation so FOVs are placed on a grid + y_offset = row_idx * grid_step * pixel_size + x_offset = col_idx * grid_step * pixel_size + pos.set_transform( + "0", + [ + TransformationMeta( + type="scale", + scale=[1.0, 1.0, 1.0, pixel_size, pixel_size], + ), + TransformationMeta( + type="translation", + translation=[0.0, 0.0, 0.0, y_offset, x_offset], + ), + ], + ) + +print(f"Created plate at {plate_path}") + +# %% +# Open and composite the well +# ----------------------------- +# ``Well.to_xarray()`` composites all 4 FOVs into one mosaic. + +plate = open_ome_zarr(plate_path, mode="r") +_, well = next(plate.wells()) +mosaic = well.to_xarray(compositor="mean") + +print(f"Mosaic shape: {mosaic.shape}") +print(f"Mosaic Y range: [{float(mosaic.y[0]):.2f}, {float(mosaic.y[-1]):.2f}] um") +print(f"Mosaic X range: [{float(mosaic.x[0]):.2f}, {float(mosaic.x[-1]):.2f}] um") + + +# %% +# Tile, process, and blend +# -------------------------- +# Apply a function to each tile of the mosaic and blend back. 
+ + +def process(tile): + """Example: double the intensity.""" + return tile * 2 + + +result = apply_func_tiled( + mosaic, + fn=process, + tile_size={"y": 24, "x": 24}, + overlap={"y": 4, "x": 4}, + weights="gaussian", +) +print(f"\nResult shape: {result.shape}") +print(f"Lazy: {hasattr(result.data, 'dask')}") + +# %% +# Verify the result +# ------------------- + +values = result.values +expected = mosaic.values * 2 +np.testing.assert_allclose(values, expected, atol=1e-4) +print("Round-trip check: PASSED") + +# %% +# With overlap caching +# ---------------------- + +result_cached = apply_func_tiled( + mosaic, + fn=process, + tile_size={"y": 24, "x": 24}, + overlap={"y": 4, "x": 4}, + weights="gaussian", + cache="persist", +) +np.testing.assert_allclose(result_cached.values, expected, atol=1e-4) +print("Cached round-trip: PASSED") + +# %% +# Clean up + +plate.close() +tmp_dir.cleanup() diff --git a/docs/examples/run_tile_single_fov.py b/docs/examples/run_tile_single_fov.py new file mode 100644 index 00000000..ce987dba --- /dev/null +++ b/docs/examples/run_tile_single_fov.py @@ -0,0 +1,134 @@ +""" +Tile + Blend a Single FOV +========================== + +Create a synthetic OME-Zarr FOV, then tile it with overlap, +apply a processing function to each tile, and blend the results +back into a single mosaic using ``apply_func_tiled`` (xarray-native) +and ``tile_and_assemble`` (zarr output). +""" + +# %% +import os +import warnings +from tempfile import TemporaryDirectory + +import numpy as np + +from iohub.ngff import open_ome_zarr +from iohub.tile import apply_func_tiled, tile_and_assemble + +warnings.filterwarnings("ignore") + +# %% +# Create a synthetic single-FOV OME-Zarr +# ---------------------------------------- +# 1 timepoint, 2 channels, 4 Z-slices, 64x128 YX. 
+ +tmp_dir = TemporaryDirectory() +fov_path = os.path.join(tmp_dir.name, "fov.zarr") + +rng = np.random.default_rng(42) +raw = rng.random((1, 2, 4, 64, 128), dtype=np.float32) + +with open_ome_zarr(fov_path, layout="fov", mode="w-", channel_names=["GFP", "DAPI"]) as dataset: + dataset.create_image("0", raw, chunks=(1, 1, 4, 64, 128)) + dataset.set_scale("0", "y", 0.325) + dataset.set_scale("0", "x", 0.325) + +print(f"Created FOV at {fov_path}") + +# %% +# Open and inspect the data +# -------------------------- + +pos = open_ome_zarr(fov_path, mode="r") +data = pos.to_xarray() +print(f"Shape: {data.shape} dims: {data.dims}") +print(f"Y range: [{float(data.y[0]):.2f}, {float(data.y[-1]):.2f}] um") +print(f"X range: [{float(data.x[0]):.2f}, {float(data.x[-1]):.2f}] um") + + +# %% +# apply_func_tiled: xarray-native (no zarr output) +# ------------------------------------------- +# Tile, apply a function, blend back. Result stays lazy until ``.values``. + + +def my_algorithm(tile): + """Example: scale by 2 and add 1.""" + return tile * 2 + 1 + + +result = apply_func_tiled( + data, + fn=my_algorithm, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + weights="gaussian", +) +print(f"Result shape: {result.shape}, lazy: {hasattr(result.data, 'dask')}") +print(f"Coords preserved: c={list(result.c.values)}") + +# Trigger computation and verify +values = result.values +expected = raw * 2 + 1 +np.testing.assert_allclose(values, expected, atol=1e-4) +print("Round-trip check: PASSED") + +# %% +# apply_func_tiled with overlap caching +# -------------------------------- +# ``cache="persist"`` pre-loads overlap strips so they aren't read twice. +# ``cache="bfs"`` reorders tile processing for cache locality. 
+ +result_cached = apply_func_tiled( + data, + fn=my_algorithm, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + weights="gaussian", + cache="persist", +) +np.testing.assert_allclose(result_cached.values, expected, atol=1e-4) +print("Cached round-trip: PASSED") + +# %% +# tile_and_assemble: zarr output +# -------------------------------- +# Same pipeline, but writes to zarr on disk. + +out_path = os.path.join(tmp_dir.name, "result.zarr") +result_zarr = tile_and_assemble( + data, + fn=my_algorithm, + tile_size={"y": 32, "x": 64}, + output=out_path, + overlap={"y": 8, "x": 16}, + weights="gaussian", +) +print(f"Output zarr: {out_path}") +np.testing.assert_allclose(result_zarr.values, expected, atol=1e-4) +print("Zarr round-trip: PASSED") + +# %% +# Identity round-trip with different blenders +# ----------------------------------------------- +# Verify that blending is correct: ``fn=identity`` recovers the original. + +for blender in ["uniform", "gaussian", "distance"]: + r = apply_func_tiled( + data, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + weights=blender, + ) + maxerr = float(np.max(np.abs(r.values - raw))) + print(f" {blender:10s} identity max error: {maxerr:.2e}") + +# %% +# Clean up + +pos.close() +tmp_dir.cleanup() diff --git a/docs/examples/run_tile_slurm_gaussian.py b/docs/examples/run_tile_slurm_gaussian.py new file mode 100644 index 00000000..892fc00a --- /dev/null +++ b/docs/examples/run_tile_slurm_gaussian.py @@ -0,0 +1,144 @@ +# /// script +# requires-python = ">=3.11" +# dependencies = [ +# "iohub @ file:///hpc/mydata/sricharan.varra/repos/iohub.tile", +# "submitit", +# "scipy", +# "numpy", +# "networkx", +# "zarr", +# "xarray", +# ] +# /// +""" +Tiled Gaussian Blur via SLURM +============================== + +Demonstrates the three-phase tile-stitch pipeline with SLURM parallelism: + + 1. ``create_tile_store`` — partition a FOV into overlapping tiles + 2. 
``process_tiles`` — applied via SLURM array jobs (one job per batch) + 3. ``stitch_from_store`` — blend tile results into a final OME-Zarr + +The processing function (Gaussian blur) is defined at module level so +submitit can pickle it for SLURM workers. + +Usage:: + + uv run docs/examples/run_tile_slurm_gaussian.py +""" + +from __future__ import annotations + +import logging +import shutil +import time +from pathlib import Path + +import numpy as np +import submitit +from scipy.ndimage import gaussian_filter + +from iohub.ngff import open_ome_zarr +from iohub.tile import create_tile_store, process_tiles, stitch_from_store + +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Paths and parameters — edit these for your dataset +# --------------------------------------------------------------------------- + +INPUT = Path("/hpc/mydata/sricharan.varra/data/tile-testing/deskewed_t0_c0_pyramid_3.zarr") +TEMP_STORE = Path("/hpc/mydata/sricharan.varra/data/tile-testing/slurm_gaussian_tiles.zarr") +OUTPUT = Path("/hpc/mydata/sricharan.varra/data/tile-testing/slurm_gaussian_output.zarr") +SLURM_LOGS = Path("/hpc/mydata/sricharan.varra/data/tile-testing/slurm_gaussian_logs") + +TILE_SIZE = {"z": 32, "y": 128, "x": 128} +OVERLAP = {"z": 16, "y": 32, "x": 32} +TILE_BATCH_SIZE = 16 # tiles per SLURM job +SIGMA = 2.0 # Gaussian blur sigma in pixels + + +# --------------------------------------------------------------------------- +# Processing function +# Must be module-level (not a lambda or closure) to be picklable by submitit. 
# ---------------------------------------------------------------------------
# Processing function
# Must be module-level (not a lambda or closure) to be picklable by submitit.
# ---------------------------------------------------------------------------


def gaussian_blur(tile, sigma: float = 2.0):
    """Apply a 3D Gaussian blur to each (T, C) slice of a tile.

    Parameters
    ----------
    tile
        Tile exposing a ``.values`` ndarray of shape (T, C, Z, Y, X) —
        assumed 5D per the pipeline convention; TODO confirm.
    sigma : float
        Blur sigma in pixels. Keyword with a default so the pickled
        worker function does not depend on module-global state at
        call time (the orchestrator binds the configured ``SIGMA``
        explicitly via ``functools.partial``).

    Returns
    -------
    np.ndarray
        float32 blurred volume with the same shape as the input.
    """
    data = tile.values.astype(np.float32)  # scipy requires a float dtype
    out = np.zeros_like(data)
    # Blur each (T, C) volume independently; Z/Y/X are blurred jointly in 3D.
    for t in range(data.shape[0]):
        for c in range(data.shape[1]):
            out[t, c] = gaussian_filter(data[t, c], sigma=sigma)
    return out


# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------


def main():
    """Run the three-phase tile pipeline (partition → process → stitch) on SLURM."""
    from functools import partial

    # Start from a clean slate — scratch and output stores are recreated.
    for p in (TEMP_STORE, OUTPUT):
        if p.exists():
            shutil.rmtree(p)
    SLURM_LOGS.mkdir(parents=True, exist_ok=True)

    pos = open_ome_zarr(str(INPUT), layout="fov")
    logger.info("Input: %s shape=%s dtype=%s", INPUT.name, pos.data.shape, pos.data.dtype)

    # Phase 1: partition into overlapping tiles stored in a scratch zarr.
    batches = create_tile_store(
        pos,
        tile_size=TILE_SIZE,
        store=str(TEMP_STORE),
        overlap=OVERLAP,
        tile_batch_size=TILE_BATCH_SIZE,
    )
    logger.info("%d tiles → %d batches", sum(len(b) for b in batches), len(batches))

    # Phase 2: submit one SLURM job per batch.
    executor = submitit.AutoExecutor(folder=str(SLURM_LOGS), cluster="slurm")
    executor.update_parameters(
        slurm_job_name="tile-gaussian",
        slurm_partition="cpu",
        slurm_mem_per_cpu="4G",
        slurm_cpus_per_task=2,
        slurm_array_parallelism=100,
        slurm_time=15,
    )

    # Bind the configured sigma at submit time so workers see this value
    # regardless of the module constant on the worker side.
    blur = partial(gaussian_blur, sigma=SIGMA)

    t0 = time.time()
    jobs = []
    with submitit.helpers.clean_env(), executor.batch():
        for batch in batches:
            jobs.append(executor.submit(process_tiles, pos, blur, str(TEMP_STORE), batch))

    logger.info("Waiting for %d jobs...", len(jobs))
    for job in submitit.helpers.as_completed(jobs):
        job.result()  # raises immediately if a job failed
        logger.info("  Job %s done", job.job_id)
    logger.info("All jobs complete in %.1fs", time.time() - t0)

    # Phase 3: stitch and blend tile results into the final OME-Zarr.
    stitch_from_store(str(TEMP_STORE), str(OUTPUT), pos, weights="gaussian")
    logger.info("Output: %s (%.1fs total)", OUTPUT.name, time.time() - t0)

    # Verify: same shape, values changed by the blur, and all finite.
    result = open_ome_zarr(str(OUTPUT), layout="fov").data[:].astype(np.float32)
    original = pos.data[:].astype(np.float32)
    assert result.shape == original.shape
    assert not np.allclose(result, original), "Blur should change values"
    assert np.isfinite(result).all(), "Output contains NaN or inf"
    logger.info("Max diff from input: %.4e PASSED", np.abs(result - original).max())

    # Cleanup scratch and output stores.
    shutil.rmtree(TEMP_STORE)
    shutil.rmtree(OUTPUT)


if __name__ == "__main__":
    main()
+ +To suppress:: + + import warnings + from iohub._experimental import ExperimentalWarning + + warnings.filterwarnings("ignore", category=ExperimentalWarning) +""" + +from __future__ import annotations + +import functools +import warnings +from typing import TypeVar, overload + +_T = TypeVar("_T") + + +class ExperimentalWarning(FutureWarning): + """Warning emitted when an experimental API is used.""" + + pass + + +@overload +def experimental(obj: _T) -> _T: ... + + +@overload +def experimental( + obj: None = None, + *, + message: str | None = None, + since: str | None = None, +) -> _T: ... + + +def experimental( + obj=None, + *, + message: str | None = None, + since: str | None = None, +): + """Mark a class or function as experimental. + + Emits :class:`ExperimentalWarning` on first use (once per process). + Sets ``__experimental__`` attribute for introspection. + + Parameters + ---------- + message : str | None + Custom warning message. Auto-generated from the object name if None. + since : str | None + Version when the feature was introduced as experimental. + """ + + def _decorator(obj): + msg = message or f"{obj.__qualname__} is experimental. API may change between versions." + if since: + msg += f" Since version {since}." + + if isinstance(obj, type): + return _wrap_class(obj, msg) + elif callable(obj): + return _wrap_callable(obj, msg) + else: + raise TypeError(f"@experimental can only decorate classes and callables, got {type(obj)}") + + if obj is not None: + # Called as @experimental without parens + return _decorator(obj) + # Called as @experimental(...) 
with parens + return _decorator + + +def _wrap_class(cls: type, msg: str) -> type: + """Wrap a class to warn on instantiation.""" + cls.__experimental__ = msg + + original_init = cls.__init__ + + @functools.wraps(original_init) + def _init_wrapper(self, *args, **kwargs): + warnings.warn(msg, ExperimentalWarning, stacklevel=2) + return original_init(self, *args, **kwargs) + + cls.__init__ = _init_wrapper + return cls + + +def _wrap_callable(fn, msg: str): + """Wrap a callable to warn on each call.""" + fn.__experimental__ = msg + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + warnings.warn(msg, ExperimentalWarning, stacklevel=2) + return fn(*args, **kwargs) + + wrapper.__experimental__ = msg + return wrapper diff --git a/src/iohub/ngff/nodes.py b/src/iohub/ngff/nodes.py index bde4c389..413cf169 100644 --- a/src/iohub/ngff/nodes.py +++ b/src/iohub/ngff/nodes.py @@ -1367,6 +1367,47 @@ def to_xarray(self) -> xr.DataArray: attrs=saved_attrs, ) + def tile( + self, + tile_size: dict[str, int], + overlap: dict[str, int] | None = None, + as_xarray: bool = False, + **kwargs, + ): + """Tile this FOV in YX with optional overlap. + + Returns a :class:`~iohub.tile.Tiler` that yields + :class:`~iohub.tile.Tile` objects (or ``xr.DataArray`` + when *as_xarray=True*). + + Parameters + ---------- + tile_size : dict[str, int] + Tile size per spatial dimension, + e.g. ``{"y": 1024, "x": 1024}``. + overlap : dict[str, int] | None + Overlap between adjacent tiles, + e.g. ``{"y": 128, "x": 128}``. + as_xarray : bool + If True, iteration yields ``xr.DataArray`` instead of + ``Tile``. + **kwargs + Forwarded to :class:`~iohub.tile.Tiler` (e.g. ``mode``). + + Returns + ------- + iohub.tile.Tiler + """ + from iohub.tile import Tiler + + return Tiler( + self.to_xarray(), + tile_size=tile_size, + overlap=overlap, + as_xarray=as_xarray, + **kwargs, + ) + def write_xarray(self, data_array: xr.DataArray, image: str = "0") -> None: """Write an xarray.DataArray into this Position. 
    def to_xarray(
        self,
        layout_resolver=None,
        compositor: str | object = "mean",
    ) -> "xr.DataArray":
        """Full well mosaic as one dask-backed xr.DataArray.

        Composites overlapping FOVs into a single array using the
        sweep-line algorithm. The result is lazy — no data is loaded
        until ``.compute()`` or ``.values`` is called.

        Parameters
        ----------
        layout_resolver : LayoutResolver | None
            Strategy for adjusting FOV coordinates (e.g.
            ``StitchingYAMLResolver``). If None, uses existing
            OME-NGFF coordinateTransformations.
        compositor : str | Compositor
            Overlap compositing strategy. Built-in: ``"mean"``,
            ``"max"``, ``"first"``. Default: ``"mean"``.

        Returns
        -------
        xr.DataArray
            5D labeled array spanning the full well mosaic.
        """
        # Local imports — presumably to avoid a circular import between
        # ngff.nodes and the iohub.tile package; TODO confirm.
        from iohub.tile._composite import _composite_fovs
        from iohub.tile._compositors import get_compositor

        # Collect each FOV as a labeled array, keeping the position path
        # alongside so a layout resolver can match FOVs to external metadata.
        position_paths = []
        fov_xarrays = []
        for name, pos in self.positions():
            fov_xarrays.append(pos.to_xarray())
            position_paths.append(name)

        # Optionally adjust FOV coordinates from an external source
        # (e.g. a stitching YAML) before compositing.
        if layout_resolver is not None:
            fov_xarrays = layout_resolver.resolve(fov_xarrays, position_paths)

        # String names resolve to built-in compositors; Compositor objects
        # pass through unchanged.
        compositor_obj = get_compositor(compositor) if isinstance(compositor, str) else compositor
        return _composite_fovs(fov_xarrays, compositor=compositor_obj)
+ compositor : str | Compositor + Overlap compositing strategy. + as_xarray : bool + If True, iteration yields ``xr.DataArray``. + **kwargs + Forwarded to :class:`~iohub.tile.Tiler`. + + Returns + ------- + iohub.tile.Tiler + """ + from iohub.tile import Tiler + + xa = self.to_xarray(layout_resolver=layout_resolver, compositor=compositor) + return Tiler( + xa, + tile_size=tile_size, + overlap=overlap, + as_xarray=as_xarray, + **kwargs, + ) + class Row(NGFFNode): """The Zarr group level containing wells. diff --git a/src/iohub/tile/__init__.py b/src/iohub/tile/__init__.py new file mode 100644 index 00000000..5d97529f --- /dev/null +++ b/src/iohub/tile/__init__.py @@ -0,0 +1,61 @@ +"""iohub.tile — Tile, process, and reassemble large image volumes.""" + +from iohub._experimental import ExperimentalWarning +from iohub.tile._blenders import ( + BlendContext, + Blender, + DistanceBlender, + GaussianBlender, + UniformBlender, + get_blender, +) +from iohub.tile._compositors import ( + CompositeContext, + Compositor, + FirstCompositor, + MaxCompositor, + MeanCompositor, + get_compositor, +) +from iohub.tile._registry import register_strategy +from iohub.tile._resolvers import ( + LayoutResolver, + StitchingYAMLResolver, + TransformResolver, +) +from iohub.tile._tiler import SamplingMode, Tile, Tiler +from iohub.tile.tile import ( + CacheMode, + apply_func_tiled, + create_tile_store, + process_tiles, + stitch_from_store, +) + +__all__ = [ + "BlendContext", + "Blender", + "CacheMode", + "CompositeContext", + "Compositor", + "DistanceBlender", + "ExperimentalWarning", + "FirstCompositor", + "GaussianBlender", + "LayoutResolver", + "MaxCompositor", + "MeanCompositor", + "SamplingMode", + "Tile", + "Tiler", + "StitchingYAMLResolver", + "TransformResolver", + "UniformBlender", + "apply_func_tiled", + "create_tile_store", + "get_blender", + "get_compositor", + "process_tiles", + "register_strategy", + "stitch_from_store", +] diff --git a/src/iohub/tile/_assembler.py 
b/src/iohub/tile/_assembler.py new file mode 100644 index 00000000..2a275d1d --- /dev/null +++ b/src/iohub/tile/_assembler.py @@ -0,0 +1,319 @@ +"""Assembler — write processed tiles back to zarr with overlap blending. + +Implements the weighted accumulation pattern (from patchly's Aggregator): + + output[bbox] += tile_data * weight + weight_map[bbox] += weight + result = output / weight_map + +Both accumulator and weight map are zarr arrays on disk, enabling +out-of-core reassembly and concurrent writes from multiple workers. +""" + +from __future__ import annotations + +import logging +import threading +from itertools import product +from pathlib import Path + +import numpy as np +import xarray as xr +import zarr + +from iohub._experimental import experimental +from iohub.ngff import open_ome_zarr +from iohub.tile._blenders import Blender, get_blender +from iohub.tile._tiler import Tile, Tiler + +logger = logging.getLogger(__name__) + + +def _resolve_chunks( + dims: tuple[str, ...], + shape: tuple[int, ...], + data: xr.DataArray, + chunks: dict[str, int] | None, +) -> tuple[int, ...]: + """Resolve zarr chunk sizes from explicit dict, source data, or array shape.""" + if chunks is not None: + return tuple(chunks.get(d, s) for d, s in zip(dims, shape)) + if data.chunks is not None: + return tuple(data.chunks[data.dims.index(d)][0] if d in data.dims else s for d, s in zip(dims, shape)) + return shape + + +def _create_scratch_zarr( + store: str | zarr.Group, + name: str | None, + *, + shape: tuple[int, ...], + chunks: tuple[int, ...], + dtype: np.dtype, + fill_value: float, + shards: tuple[int, ...] | None = None, +) -> zarr.Array: + """Create a temporary zarr array for accumulation scratch space. + + TODO: Consider using iohub's zarr utilities instead of raw zarr APIs. 
+ """ + if isinstance(store, zarr.Group): + return store.create_array( + name, + shape=shape, + chunks=chunks, + dtype=dtype, + fill_value=fill_value, + shards=shards, + ) + if shards is not None: + from zarr.codecs import ShardingCodec + + return zarr.open_array( + store, + mode="w", + shape=shape, + chunks=shards, + dtype=dtype, + fill_value=fill_value, + codecs=[ShardingCodec(chunk_shape=chunks)], + ) + return zarr.open_array( + store, + mode="w", + shape=shape, + chunks=chunks, + dtype=dtype, + fill_value=fill_value, + ) + + +@experimental +class Assembler: + """Reassemble processed tiles into an OME-Zarr with overlap blending. + + Parameters + ---------- + tiler : Tiler + The tiler used to generate the tiles. Provides mosaic shape + and overlap metadata. + output : str | Path + Path for the output OME-Zarr store. + source_position : Position + Source Position node — metadata (channel names, scale transforms) + is copied to the output OME-Zarr. + weights : str | Blender + Blending strategy. Built-in: ``"gaussian"`` (default), ``"uniform"``. + dtype : np.dtype | None + Output dtype. Defaults to the input data's dtype. + chunks : dict[str, int] | None + Chunk sizes for the output zarr, keyed by dim name. + Defaults to the tiler's data chunk sizes. + shards : dict[str, int] | None + Shard sizes for the output zarr, keyed by dim name. + When provided, multiple chunks are grouped into larger + shard files — useful for HPC parallel filesystems. 
+ """ + + def __init__( + self, + tiler: Tiler, + output: str | Path, + source_position, + weights: str | Blender = "gaussian", + dtype: np.dtype | None = None, + chunks: dict[str, int] | None = None, + shards: dict[str, int] | None = None, + ): + self._tiler = tiler + self._blender = get_blender(weights) + data = tiler.data + self._dtype = np.dtype(dtype) if dtype is not None else data.dtype + # Accumulator needs at least float32 precision for weighted sums + self._accum_dtype = np.float32 if self._dtype.itemsize < 4 else self._dtype + self._finalized = False + self._result: xr.DataArray | None = None + self._source_position = source_position + + # Full output dims from the data (e.g. ("t", "c", "z", "y", "x")) + self._dims = tuple(str(d) for d in data.dims) + self._shape = tuple(data.sizes[d] for d in self._dims) + self._tile_dims = tiler.tile_dims + + self._chunks = _resolve_chunks(self._dims, self._shape, data, chunks) + + self._shards = None + if shards is not None: + self._shards = tuple(shards.get(d, c) for d, c in zip(self._dims, self._chunks)) + + self._output_path = Path(output) + + accum_store = str(self._output_path.parent / (self._output_path.name + ".accum")) + weight_store = str(self._output_path.parent / (self._output_path.name + ".weights")) + + self._accum = _create_scratch_zarr( + accum_store, + None, + shape=self._shape, + chunks=self._chunks, + dtype=self._accum_dtype, + fill_value=0.0, + shards=self._shards, + ) + self._weight_map = _create_scratch_zarr( + weight_store, + None, + shape=self._shape, + chunks=self._chunks, + dtype=self._accum_dtype, + fill_value=0.0, + shards=self._shards, + ) + + # Cache weight kernels by tile shape + self._weight_cache: dict[tuple[int, ...], np.ndarray] = {} + + self._overlap = tiler.overlap + + # Lock for thread-safe append (protects read-modify-write on zarr chunks) + self._lock = threading.Lock() + + def validate_parallel_safety(self) -> bool: + """Check if tiles can be safely processed in parallel. 
+ + Returns True if no two tiles write to overlapping zarr chunks. + When False, tiles must be processed in waves (see Tiler.graph + for coloring-based wave scheduling). + """ + # Map dim name -> chunk size for tiled dimensions + dim_chunk_sizes = {d: self._chunks[self._dims.index(d)] for d in self._tile_dims} + + seen_chunks: set[tuple[int, ...]] = set() + + for tile in self._tiler: + # Compute chunk indices touched by this tile in each tiled dim + chunk_ranges = [] + for dim in self._tile_dims: + s = tile.slices[dim] + cs = dim_chunk_sizes[dim] + chunk_ranges.append(range(s.start // cs, (s.stop - 1) // cs + 1)) + + tile_chunks = set(product(*chunk_ranges)) + + if tile_chunks & seen_chunks: + return False + seen_chunks |= tile_chunks + + return True + + def _get_weight_kernel(self, tile_shape: tuple[int, ...]) -> np.ndarray: + """Get or compute the weight kernel for a tile shape.""" + if tile_shape not in self._weight_cache: + w = self._blender.weights(tile_shape, self._overlap) + self._weight_cache[tile_shape] = w.astype(self._accum_dtype) + return self._weight_cache[tile_shape] + + def append(self, tile: Tile, result: xr.DataArray | np.ndarray): + """Write a processed tile into the accumulator. + + Parameters + ---------- + tile : Tile + The tile spec (provides spatial slices into the mosaic). + result : xr.DataArray | np.ndarray + Processed tile data. Shape must match tile extent + (broadcast dimensions are supported). + """ + if self._finalized: + raise RuntimeError("Assembler already finalized. 
Create a new Assembler to write more tiles.") + + data = result.values if isinstance(result, xr.DataArray) else result + data = data.astype(self._accum_dtype) + + # Weight kernel covers tiled dims, broadcast over non-tiled leading dims + weight = self._get_weight_kernel(tile.tile_shape) + + # Build index: slice for tiled dims, slice(None) for others + idx = tuple(tile.slices[d] if d in tile.slices else slice(None) for d in self._dims) + + # Compute weighted data before acquiring lock + weighted_data = data * weight + + # Thread-safe read-modify-write on zarr chunks + with self._lock: + existing_accum = self._accum[idx] + self._accum[idx] = existing_accum + weighted_data + + existing_weight = self._weight_map[idx] + self._weight_map[idx] = existing_weight + weight + + def get_output(self) -> xr.DataArray: + """Finalize: normalize and write to an OME-Zarr output store. + + Creates a proper OME-Zarr with metadata copied from the source + Position. Processes chunk-by-chunk to avoid OOM. + Idempotent — safe to call multiple times. + + Returns + ------- + xr.DataArray + Result backed by the output OME-Zarr. 
+ """ + if self._finalized and self._result is not None: + return self._result + + from copy import deepcopy + + src = self._source_position + + # Create output OME-Zarr with metadata from source + dst = open_ome_zarr( + str(self._output_path), + layout="fov", + mode="w-", + channel_names=list(src.channel_names), + ) + src_transforms = deepcopy(src.metadata.multiscales[0].datasets[0].coordinate_transformations) + dst.create_zeros( + "0", + shape=self._shape, + dtype=self._dtype, + chunks=self._chunks, + ) + dst.metadata.multiscales[0].datasets[0].coordinate_transformations = src_transforms + dst.dump_meta() + + output = dst["0"] + + # Normalize chunk-by-chunk to avoid OOM + all_chunk_slices = list(self._iter_chunk_slices()) + n_chunks = len(all_chunk_slices) + logger.info("Normalizing %d chunks...", n_chunks) + for i, slices in enumerate(all_chunk_slices): + accum_chunk = self._accum[slices] + weight_chunk = self._weight_map[slices] + + with np.errstate(divide="ignore", invalid="ignore"): + normalized = np.where( + weight_chunk > 0, + accum_chunk / weight_chunk, + 0.0, + ) + output[slices] = normalized.astype(self._dtype) + + if (i + 1) % 100 == 0 or i + 1 == n_chunks: + logger.info(" Normalized %d/%d chunks", i + 1, n_chunks) + + self._result = dst.to_xarray() + self._finalized = True + return self._result + + def _iter_chunk_slices(self): + """Yield tuple-of-slices for each chunk in the zarr array.""" + ranges = [] + for dim_size, chunk_size in zip(self._shape, self._chunks): + starts = list(range(0, dim_size, chunk_size)) + ranges.append([slice(s, min(s + chunk_size, dim_size)) for s in starts]) + + for combo in product(*ranges): + yield tuple(combo) diff --git a/src/iohub/tile/_blend.py b/src/iohub/tile/_blend.py new file mode 100644 index 00000000..9bc0a4fe --- /dev/null +++ b/src/iohub/tile/_blend.py @@ -0,0 +1,132 @@ +"""Sweep-line tile blending into a single xr.DataArray. 

Thin wrapper around :func:`_sweep_line_assemble` that applies
weighted blending in overlap regions using :class:`Blender` kernels.
"""

from __future__ import annotations

import numpy as np
import xarray as xr

from iohub.tile._blenders import Blender
from iohub.tile._sweep import _CellInfo, _sweep_line_assemble
from iohub.tile._tiler import Tile, Tiler


def _weighted_blend(
    regions: list[xr.DataArray],
    weights: list[np.ndarray],
) -> np.ndarray:
    """Blend regions with weights via weighted mean.

    Parameters
    ----------
    regions : list[xr.DataArray]
        Tile sub-regions, all the same shape (leading..., tiled...).
    weights : list[np.ndarray]
        N-D weight kernels over tiled dims, one per region.

    Returns
    -------
    np.ndarray
        Blended array with the same shape as each region.
    """
    # Stack regions along a new leading axis: (n_regions, *leading, *tiled).
    stacked = np.stack([r.values for r in regions], axis=0)
    # Weight stack is (n_regions, *tiled) — no leading dims yet.
    weight_stack = np.stack(weights, axis=0)
    # Reshape for broadcasting: add leading dim axes after the stack axis
    # so weights align with the trailing tiled dims of `stacked`.
    ndim_extra = stacked.ndim - weight_stack.ndim
    for _ in range(ndim_extra):
        weight_stack = np.expand_dims(weight_stack, axis=1)
    numerator = (stacked * weight_stack).sum(axis=0)
    denominator = weight_stack.sum(axis=0)
    # NOTE(review): assumes the summed weights are nonzero at every pixel;
    # the built-in blenders never emit an all-zero column, but a custom
    # Blender could — confirm or guard before relying on this with plugins.
    return numerator / denominator


def _blend_tiles(
    tiles: list[xr.DataArray],
    tile_specs: list[Tile],
    blender: Blender,
    tiler: Tiler,
) -> xr.DataArray:
    """Blend overlapping processed tiles into a single mosaic.

    Uses sweep-line decomposition + weighted reduction (same pattern
    as ``_composite_fovs``). Returns a lazy dask-backed xr.DataArray.

    Parameters
    ----------
    tiles : list[xr.DataArray]
        Processed tile data arrays, one per tile_spec.
    tile_specs : list[Tile]
        Tile objects from the Tiler (pixel-space positions).
    blender : Blender
        Blending strategy providing weight kernels.
    tiler : Tiler
        The Tiler that produced the tile_specs (for overlap info).

    Returns
    -------
    xr.DataArray
        Dask-backed mosaic with coordinates from the original data.
    """
    if len(tiles) != len(tile_specs):
        raise ValueError(f"tiles ({len(tiles)}) and tile_specs ({len(tile_specs)}) must have the same length")

    # Single tile: nothing to blend, pass through unchanged.
    if len(tiles) == 1:
        return tiles[0]

    overlap = tiler.overlap
    data = tiler.data
    tile_dims = tiler.tile_dims

    # Pixel-space bounding boxes from Tiles as dict[str, (start, stop)]
    tile_bboxes: list[dict[str, tuple[int, int]]] = [
        {d: (s.slices[d].start, s.slices[d].stop) for d in tile_dims} for s in tile_specs
    ]

    # Weight cache: avoid recomputing kernels for same tile shape.
    # Keyed by shape only — correct because `blender.weights` here is
    # called with the same `overlap` for every tile.
    weight_cache: dict[tuple[int, ...], np.ndarray] = {}

    def _get_weight(shape: tuple[int, ...]) -> np.ndarray:
        if shape not in weight_cache:
            weight_cache[shape] = blender.weights(shape, overlap)
        return weight_cache[shape]

    # Overlap callback: crop weight kernels and weighted-blend
    def _blend_overlap(
        cell_slices: list[xr.DataArray],
        contributing: list[int],
        info: _CellInfo,
    ) -> np.ndarray:
        cell_weights = []
        for idx in contributing:
            full_weight = _get_weight(tile_specs[idx].tile_shape)
            # Crop weight to the cell's local region within this tile
            local_slices = tuple(
                slice(
                    info.bounds[d][0] - tile_bboxes[idx][d][0],
                    info.bounds[d][1] - tile_bboxes[idx][d][0],
                )
                for d in tile_dims
            )
            cell_weights.append(full_weight[local_slices])
        return _weighted_blend(cell_slices, cell_weights)

    mosaic, global_bounds = _sweep_line_assemble(tiles, tile_bboxes, _blend_overlap, tile_dims)

    # Wrap in xarray with coordinates from original data.
    # global_bounds are pixel indices into `data`, so coordinate arrays are
    # recovered by integer-slicing the mosaic's coordinate values.
    all_dims = tuple(str(d) for d in data.dims)
    coords = {}
    for d in all_dims:
        if d in global_bounds:
            dmin, dmax = global_bounds[d]
            coords[d] = (d, data.coords[d].values[dmin:dmax], data.coords[d].attrs)
        elif d in data.coords:
            coords[d] = data.coords[d]

    return xr.DataArray(
        mosaic,
        dims=all_dims,
        coords=coords,
    )
diff --git
a/src/iohub/tile/_blenders.py b/src/iohub/tile/_blenders.py
new file mode 100644
index 00000000..53980b02
--- /dev/null
+++ b/src/iohub/tile/_blenders.py
@@ -0,0 +1,167 @@
"""Blender protocol and built-in implementations for tile overlap blending.

A Blender produces spatial weight kernels that control how overlapping tiles
are combined during reassembly. The weighted accumulation pattern
(from patchly's Aggregator):

    output[bbox] += tile_data * weight
    weight_map[bbox] += weight
    result = output / weight_map

Built-in strategies are resolved by name; third-party strategies are
discoverable via the ``iohub.blenders`` entrypoint group.
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import reduce
from typing import Protocol, runtime_checkable

import numpy as np

from iohub.tile._registry import resolve_strategy
from iohub.tile._tiler import Tile


@dataclass
class BlendContext:
    """Tile metadata passed to blenders."""

    tile_spec: Tile
    """The tile being blended."""

    neighbors: list[int]
    """Tile IDs of neighboring (overlapping) tiles."""

    is_edge: bool
    """Whether this tile touches the mosaic border."""


@runtime_checkable
class Blender(Protocol):
    """Produce spatial weight kernels for tile overlap blending."""

    def weights(
        self,
        tile_shape: tuple[int, ...],
        overlap: dict[str, int],
        metadata: BlendContext | None = None,
    ) -> np.ndarray:
        """Return an N-D weight array for the given tile shape.

        Parameters
        ----------
        tile_shape : tuple[int, ...]
            Tile size per tiled dimension, e.g. ``(Y, X)`` or ``(Z, Y, X)``.
        overlap : dict[str, int]
            Overlap in pixels, e.g. ``{"y": 128, "x": 128}``.
            Note: the built-in blenders below do not read this value —
            their kernels depend only on ``tile_shape``.
        metadata : BlendContext | None
            Optional tile context (neighbors, edge status).

        Returns
        -------
        np.ndarray
            Float64 weight array with shape ``tile_shape``.
        """
        ...


def _separable_nd(kernels_1d: list[np.ndarray]) -> np.ndarray:
    """Compute N-D separable kernel from a list of 1D kernels.

    Uses ``np.multiply.outer`` iteratively to build the outer product
    of all 1D kernels. For 2D this is equivalent to ``np.outer(a, b)``.
    """
    return reduce(np.multiply.outer, kernels_1d)


class UniformBlender:
    """Uniform weights — simple averaging in overlap regions."""

    def weights(
        self,
        tile_shape: tuple[int, ...],
        overlap: dict[str, int],
        metadata: BlendContext | None = None,
    ) -> np.ndarray:
        # Constant weight 1 everywhere → weighted mean reduces to plain mean.
        return np.ones(tile_shape, dtype=np.float64)


class GaussianBlender:
    """Gaussian weight kernel via separable 1D gaussians.

    Center-weighted blending: tiles contribute most from their centers
    and least from their edges, producing smooth transitions in overlap
    regions. Default sigma is tile_size / 8 (per patchly convention).
    """

    def __init__(self, sigma_fraction: float = 1.0 / 8.0):
        # sigma per dimension = tile size along that dimension * sigma_fraction
        self._sigma_fraction = sigma_fraction

    def weights(
        self,
        tile_shape: tuple[int, ...],
        overlap: dict[str, int],
        metadata: BlendContext | None = None,
    ) -> np.ndarray:
        kernels_1d = [self._gaussian_1d(s) for s in tile_shape]
        return _separable_nd(kernels_1d)

    def _gaussian_1d(self, size: int) -> np.ndarray:
        """1D gaussian centered in the array."""
        sigma = size * self._sigma_fraction
        center = (size - 1) / 2.0
        x = np.arange(size, dtype=np.float64)
        # Unnormalized gaussian: peak value is 1 at the center; edge values
        # are strictly positive, so the blend denominator never vanishes.
        g = np.exp(-0.5 * ((x - center) / sigma) ** 2)
        return g


class DistanceBlender:
    """Euclidean distance transform weights with cosine ramp.

    Each pixel's weight is proportional to its distance from the nearest
    tile edge, with a cosine falloff for smooth transitions. Inspired by
    multiview-stitcher's EDT blending and Preibisch et al. (2009).

    Produces smoother transitions than Gaussian blending because the
    weight profile adapts to the tile shape rather than assuming a
    fixed bell curve.
Strips
    covering identical pixel ranges are deduplicated.

    Parameters
    ----------
    tiler : Tiler
        Tiler with overlap > 0 in at least one dimension.

    Returns
    -------
    list[dict[str, slice]]
        Overlap regions as dicts of slices keyed by dim name.
    """
    tile_dims = tiler.tile_dims
    seen: set[tuple[tuple[str, int, int], ...]] = set()
    regions: list[dict[str, slice]] = []

    for a, b in tiler.graph.edges():
        spec_a, spec_b = tiler[a], tiler[b]

        # Bounding-box intersection per tiled dim; empty in any dim → no overlap.
        overlap_slices: dict[str, slice] = {}
        valid = True
        for dim in tile_dims:
            start = max(spec_a.slices[dim].start, spec_b.slices[dim].start)
            end = min(spec_a.slices[dim].stop, spec_b.slices[dim].stop)
            if end <= start:
                valid = False
                break
            overlap_slices[dim] = slice(start, end)

        if not valid:
            continue

        # Deduplicate strips that cover identical pixel ranges.
        key = tuple((d, s.start, s.stop) for d, s in overlap_slices.items())
        if key not in seen:
            seen.add(key)
            regions.append(overlap_slices)

    return regions


def _bfs_tile_order(tiler: Tiler) -> list[int]:
    """BFS traversal of the tile neighborhood graph.

    Processes adjacent tiles consecutively so their shared overlap
    chunks stay hot in cache. Starts from tile 0.

    Parameters
    ----------
    tiler : Tiler
        Tiler with a neighborhood graph.

    Returns
    -------
    list[int]
        Tile IDs in BFS order.
    """
    graph = tiler.graph
    if graph.number_of_nodes() == 0:
        return []

    visited: set[int] = set()
    order: list[int] = []
    queue: deque[int] = deque([0])
    visited.add(0)

    while queue:
        node = queue.popleft()
        order.append(node)
        # Sorted neighbors make the traversal order deterministic.
        for neighbor in sorted(graph.neighbors(node)):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    # Include any disconnected tiles (no overlap)
    for tile_id in range(len(tiler)):
        if tile_id not in visited:
            order.append(tile_id)

    return order


def _estimate_overlap_bytes(tiler: Tiler) -> int:
    """Estimate total bytes needed to cache all overlap regions.

    Parameters
    ----------
    tiler : Tiler
        Tiler with overlap.

    Returns
    -------
    int
        Approximate byte count for all overlap strips.
    """
    data = tiler.data
    tile_dims = tiler.tile_dims

    # Leading dims: all dims NOT in tiled dims
    leading = 1
    for dim in data.dims:
        if dim not in tile_dims:
            leading *= data.sizes[dim]
    itemsize = data.dtype.itemsize

    total = 0
    for overlap_slices in _overlap_regions(tiler):
        region_size = reduce(
            mul,
            (s.stop - s.start for s in overlap_slices.values()),
            1,
        )
        total += leading * region_size * itemsize
    return total


def _persist_overlaps(tiler: Tiler) -> None:
    """Pre-compute and cache overlap regions in memory.

    Slices each overlap strip from the source data and calls
    ``dask.persist()`` so they are loaded once and shared across
    tiles that read the same region.

    Parameters
    ----------
    tiler : Tiler
        Tiler with overlap > 0.
    """
    regions = _overlap_regions(tiler)
    if not regions:
        return

    data = tiler.data
    strips: list[xr.DataArray] = []
    for overlap_slices in regions:
        strips.append(data.isel(**overlap_slices))

    # NOTE(review): the persisted collections returned by dask.persist are
    # discarded here; unless the scheduler retains the computed keys by
    # other means, the cached strips may be released immediately — confirm
    # this has the intended caching effect.
    dask.persist(*strips)
diff --git a/src/iohub/tile/_composite.py b/src/iohub/tile/_composite.py
new file mode 100644
index 00000000..4712e5f0
--- /dev/null
+++ b/src/iohub/tile/_composite.py
@@ -0,0 +1,116 @@
"""Sweep-line FOV compositing into a single xr.DataArray.

Thin wrapper around :func:`_sweep_line_assemble` that handles
physical-to-pixel coordinate conversion and compositor dispatch.
"""

from __future__ import annotations

import numpy as np
import xarray as xr

from iohub.tile._compositors import CompositeContext, Compositor
from iohub.tile._sweep import _CellInfo, _sweep_line_assemble


def _pixel_spacing(coords: np.ndarray) -> float:
    """Infer pixel spacing from coordinate array.

    Assumes uniform spacing; falls back to 1.0 for single-element axes.
    """
    if len(coords) > 1:
        return float(coords[1] - coords[0])
    return 1.0


def _composite_fovs(
    fov_xarrays: list[xr.DataArray],
    compositor: Compositor,
) -> xr.DataArray:
    """Composite N FOV xarrays into one mosaic xr.DataArray.

    Uses sweep-line decomposition to partition the mosaic into
    non-overlapping cells, composites overlaps via *compositor*,
    and assembles with ``dask.array.block()``.

    Compositing dimensions are determined automatically: Y and X are
    always composited; Z is included when FOVs have Z coordinates
    that vary across FOVs.

    Parameters
    ----------
    fov_xarrays : list[xr.DataArray]
        FOV data arrays, each with physical coordinates.
    compositor : Compositor
        Strategy for combining overlapping regions.

    Returns
    -------
    xr.DataArray
        Dask-backed mosaic with physical coordinates.
    """
    # Single FOV: nothing to composite.
    if len(fov_xarrays) == 1:
        return fov_xarrays[0]

    first = fov_xarrays[0]

    # Determine which spatial dims to composite over.
    # Y and X always; Z only if FOVs have varying Z coordinates.
    composite_dims: list[str] = []
    spacings: dict[str, float] = {}

    for dim in ("z", "y", "x"):
        if dim not in first.dims or dim not in first.coords:
            continue
        # Spacing is taken from the first FOV — assumes all FOVs share
        # the same pixel size along each composited dim.
        spacing = _pixel_spacing(first.coords[dim].values)
        if dim in ("y", "x"):
            # Always composite over Y and X
            composite_dims.append(dim)
            spacings[dim] = spacing
        elif dim == "z":
            # Only composite over Z if FOVs have different Z origins
            z_origins = {float(xa.coords["z"].values[0]) for xa in fov_xarrays}
            if len(z_origins) > 1:
                composite_dims.append(dim)
                spacings[dim] = spacing

    tile_dims = tuple(composite_dims)

    # Derive pixel-space bounding boxes from physical coords.
    # NOTE(review): `round(origin / spacing)` assumes FOV origins lie on
    # (or very near) the shared pixel grid — sub-pixel offsets are snapped;
    # confirm acceptable for the stitching use case.
    fov_bboxes: list[dict[str, tuple[int, int]]] = []
    for xa in fov_xarrays:
        bbox: dict[str, tuple[int, int]] = {}
        for dim in tile_dims:
            s = spacings[dim]
            start = round(float(xa.coords[dim].values[0]) / s)
            bbox[dim] = (start, start + xa.sizes[dim])
        fov_bboxes.append(bbox)

    # Overlap callback: build CompositeContext and delegate to compositor
    def _composite_overlap(
        cell_slices: list[xr.DataArray],
        contributing: list[int],
        info: _CellInfo,
    ) -> np.ndarray:
        ctx = CompositeContext(
            overlap_bounds=info.bounds,
            fov_bounds=[fov_bboxes[idx] for idx in contributing],
        )
        return compositor.composite(cell_slices, masks=None, metadata=ctx)

    mosaic, global_bounds = _sweep_line_assemble(fov_xarrays, fov_bboxes, _composite_overlap, tile_dims)

    # Wrap in xarray with physical coordinates
    all_dims = tuple(str(d) for d in first.dims)

    coords: dict = {}
    for d in all_dims:
        if d in global_bounds:
            dmin, dmax = global_bounds[d]
            s = spacings[d]
            # Rebuild physical coordinates from pixel bounds × spacing.
            coords[d] = (d, np.arange(dmin, dmax) * s, first.coords[d].attrs)
        elif d in first.coords:
            coords[d] = first.coords[d]

    return xr.DataArray(
        mosaic,
        dims=all_dims,
        coords=coords,
    )
diff --git a/src/iohub/tile/_compositors.py b/src/iohub/tile/_compositors.py
new file mode 100644
index 00000000..ec44e423
--- /dev/null
+++
b/src/iohub/tile/_compositors.py @@ -0,0 +1,93 @@ +"""Compositor protocol and built-in implementations for FOV overlap compositing. + +A Compositor controls how overlapping FOV regions are combined when +``_composite_fovs()`` builds the mosaic dask graph. Built-in strategies +are resolved by name; third-party strategies are discoverable via the +``iohub.compositors`` entrypoint group. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + +import numpy as np +import xarray as xr + +from iohub.tile._registry import resolve_strategy + + +@dataclass +class CompositeContext: + """Context passed to compositors about the overlap region.""" + + overlap_bounds: dict[str, tuple[int, int]] + """Pixel-space bounds of the overlap cell, e.g. ``{"y": (0, 32), "x": (0, 64)}``.""" + + fov_bounds: list[dict[str, tuple[int, int]]] + """Each contributing FOV's full bounds in pixel space.""" + + +@runtime_checkable +class Compositor(Protocol): + """Combine data from overlapping FOVs into a single output region.""" + + def composite( + self, + regions: list[xr.DataArray], + masks: list[np.ndarray] | None, + metadata: CompositeContext, + ) -> np.ndarray: ... + + +class MeanCompositor: + """Simple average in overlap regions (default).""" + + def composite( + self, + regions: list[xr.DataArray], + masks: list[np.ndarray] | None, + metadata: CompositeContext, + ) -> np.ndarray: + stacked = xr.concat(regions, dim="__fov__") + return stacked.mean(dim="__fov__").values + + +class MaxCompositor: + """Maximum intensity projection across FOVs.""" + + def composite( + self, + regions: list[xr.DataArray], + masks: list[np.ndarray] | None, + metadata: CompositeContext, + ) -> np.ndarray: + stacked = xr.concat(regions, dim="__fov__") + return stacked.max(dim="__fov__").values + + +class FirstCompositor: + """First FOV wins (no blending). 
Fastest.""" + + def composite( + self, + regions: list[xr.DataArray], + masks: list[np.ndarray] | None, + metadata: CompositeContext, + ) -> np.ndarray: + return regions[0].values + + +_BUILTINS: dict[str, type] = { + "mean": MeanCompositor, + "max": MaxCompositor, + "first": FirstCompositor, +} + + +def get_compositor(name: str | Compositor) -> Compositor: + """Resolve a compositor by name or pass through an object. + + Checks built-in names first, then ``iohub.compositors`` entrypoints. + """ + return resolve_strategy(name, _BUILTINS, "iohub.compositors", "compositor") diff --git a/src/iohub/tile/_registry.py b/src/iohub/tile/_registry.py new file mode 100644 index 00000000..b038b593 --- /dev/null +++ b/src/iohub/tile/_registry.py @@ -0,0 +1,87 @@ +"""Shared entrypoint-based strategy resolution for plugin protocols. + +Used by both compositors and blenders to resolve names to instances +via built-in dicts and ``importlib.metadata`` entrypoints. +""" + +from __future__ import annotations + +from importlib.metadata import entry_points + +_RUNTIME_REGISTRY: dict[str, dict[str, type]] = {} + + +def register_strategy( + name: str, + cls: type, + entrypoint_group: str, + *, + overwrite: bool = False, + aliases: list[str] | None = None, +) -> None: + """Register a strategy class at runtime. + + Parameters + ---------- + name : str + Primary name for the strategy. + cls : type + The strategy class (must be instantiable with no args). + entrypoint_group : str + The entrypoint group, e.g. ``"iohub.blenders"``. + overwrite : bool + If False (default), raises ValueError if name is already registered. + aliases : list[str] | None + Optional alternative names. + """ + group = _RUNTIME_REGISTRY.setdefault(entrypoint_group, {}) + for n in [name] + (aliases or []): + if n in group and not overwrite: + raise ValueError(f"Strategy {n!r} already registered in {entrypoint_group}. 
Use overwrite=True to replace.") + group[n] = cls + + +def _clear_runtime_registry(entrypoint_group: str) -> None: + """Remove all runtime-registered strategies for a group. For testing.""" + _RUNTIME_REGISTRY.pop(entrypoint_group, None) + + +def resolve_strategy( + name: str | object, + builtins: dict[str, type], + entrypoint_group: str, + kind: str, +) -> object: + """Resolve a strategy by name or pass through an existing instance. + + Parameters + ---------- + name : str | object + Strategy name (looked up in builtins then entrypoints) + or an already-instantiated strategy object (returned as-is). + builtins : dict[str, type] + Built-in name → class mapping. + entrypoint_group : str + Entrypoint group to search, e.g. ``"iohub.blenders"``. + kind : str + Human-readable label for error messages (e.g. ``"blender"``). + """ + if not isinstance(name, str): + return name + + runtime = _RUNTIME_REGISTRY.get(entrypoint_group, {}) + if name in runtime: + return runtime[name]() + + # Builtins + if name in builtins: + return builtins[name]() + + # Entrypoints + eps = entry_points(group=entrypoint_group) + for ep in eps: + if ep.name == name: + return ep.load()() + + available = list(runtime) + list(builtins) + [ep.name for ep in eps] + raise ValueError(f"Unknown {kind}: {name!r}. Available: {available}") diff --git a/src/iohub/tile/_resolvers.py b/src/iohub/tile/_resolvers.py new file mode 100644 index 00000000..8f1e8698 --- /dev/null +++ b/src/iohub/tile/_resolvers.py @@ -0,0 +1,137 @@ +"""LayoutResolver protocol and implementations for FOV translation sources. + +A LayoutResolver adjusts FOV xarray coordinates based on external +translation data (stitching YAML, position naming convention, etc.). 
"""

from __future__ import annotations

from pathlib import Path
from typing import Protocol, runtime_checkable

import numpy as np
import xarray as xr
import yaml


@runtime_checkable
class LayoutResolver(Protocol):
    """Adjust FOV coordinates based on external translation sources."""

    def resolve(
        self,
        fov_xarrays: list[xr.DataArray],
        position_paths: list[str],
    ) -> list[xr.DataArray]:
        """Return FOV xarrays with updated YX coordinates.

        Parameters
        ----------
        fov_xarrays : list[xr.DataArray]
            FOV data arrays with existing coordinates.
        position_paths : list[str]
            Position paths within the zarr store (e.g. ``"000000"``).
        """
        ...


class TransformResolver:
    """No-op resolver - trusts existing OME-NGFF coordinateTransformations.

    Use this when Position metadata already has correct translations
    (i.e. ``Position.to_xarray()`` already produces correct coordinates).
    """

    def resolve(
        self,
        fov_xarrays: list[xr.DataArray],
        position_paths: list[str],
    ) -> list[xr.DataArray]:
        # Pass-through: coordinates are already correct.
        return fov_xarrays


class StitchingYAMLResolver:
    """Reads ZYX pixel translations from a stitching YAML config.

    YAML format::

        total_translation:
          well/row/posname:
            - z_pixels
            - y_pixels
            - x_pixels

    The resolver matches position paths against YAML keys and updates
    the xarray z/y/x coordinates accordingly.

    Parameters
    ----------
    yaml_path : str | Path
        Path to the stitching YAML file.
    well_path : str | None
        Well prefix for matching YAML keys (e.g. ``"0/1"``).
        If None, tries to infer from YAML keys.
    """

    def __init__(self, yaml_path: str | Path, well_path: str | None = None):
        self._yaml_path = Path(yaml_path)
        self._well_path = well_path
        # Config is loaded eagerly; a missing "total_translation" section
        # raises KeyError here rather than at resolve() time.
        with open(self._yaml_path) as f:
            self._config = yaml.safe_load(f)
        self._translations = self._config["total_translation"]

    def resolve(
        self,
        fov_xarrays: list[xr.DataArray],
        position_paths: list[str],
    ) -> list[xr.DataArray]:
        result = []
        for xa, pos_path in zip(fov_xarrays, position_paths):
            # Build lookup key: well_path/pos_path or just pos_path
            key = f"{self._well_path}/{pos_path}" if self._well_path else pos_path

            if key not in self._translations:
                # Try matching by position name suffix
                key = self._find_matching_key(pos_path)

            if key is None:
                raise KeyError(
                    f"Position {pos_path!r} not found in stitching YAML. "
                    f"Available keys: {list(self._translations.keys())[:5]}..."
                )

            z_px, y_px, x_px = self._translations[key]

            # Infer pixel size from coordinate spacing
            y_coords = xa.coords["y"].values
            x_coords = xa.coords["x"].values
            sy = float(y_coords[1] - y_coords[0]) if len(y_coords) > 1 else 1.0
            sx = float(x_coords[1] - x_coords[0]) if len(x_coords) > 1 else 1.0

            # Build new coordinates with translation applied:
            # YAML translations are in pixels, so they are scaled by the
            # inferred spacing to get physical coordinates.
            new_y = np.arange(len(y_coords)) * sy + y_px * sy
            new_x = np.arange(len(x_coords)) * sx + x_px * sx

            coord_updates = {
                "y": ("y", new_y, xa.coords["y"].attrs),
                "x": ("x", new_x, xa.coords["x"].attrs),
            }

            # Apply Z translation if z coordinate has spacing info
            if "z" in xa.coords and len(xa.coords["z"].values) > 1:
                z_coords = xa.coords["z"].values
                sz = float(z_coords[1] - z_coords[0])
                new_z = np.arange(len(z_coords)) * sz + z_px * sz
                coord_updates["z"] = ("z", new_z, xa.coords["z"].attrs)

            xa = xa.assign_coords(**coord_updates)
            result.append(xa)

        return result

    def _find_matching_key(self, pos_path: str) -> str | None:
        """Find a YAML key ending with the position path."""
        for key in self._translations:
            if
key.endswith(f"/{pos_path}") or key == pos_path: + return key + return None diff --git a/src/iohub/tile/_sweep.py b/src/iohub/tile/_sweep.py new file mode 100644 index 00000000..e9c2a9fb --- /dev/null +++ b/src/iohub/tile/_sweep.py @@ -0,0 +1,181 @@ +"""Shared sweep-line decomposition for mosaic assembly. + +Partitions a set of overlapping rectangular regions into a non-overlapping +cell grid via sweep-line decomposition, then assembles them into a single +dask array with ``da.block()``. + +Used by both ``_composite_fovs`` (FOV stitching) and ``_blend_tiles`` +(tile blending) — the only difference is the overlap handler callback. + +Supports N-D tiling dimensions (e.g. YX or ZYX). Non-tiled leading +dimensions (e.g. T, C) are passed through unchanged. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from itertools import product +from typing import Protocol, runtime_checkable + +import dask +import dask.array as da +import numpy as np +import xarray as xr + + +@dataclass(frozen=True, slots=True) +class _CellInfo: + """Pixel-space bounds of a single cell in the sweep-line grid.""" + + bounds: dict[str, tuple[int, int]] + + +@runtime_checkable +class OverlapHandler(Protocol): + """Resolve overlap where 2+ regions cover the same cell. + + Implementations receive the cropped sub-regions, the indices of the + contributing regions (into the original list), and the cell's + pixel-space bounds. They must return the blended/composited result + as a numpy array with the same shape as each sub-region. + """ + + def __call__( + self, + cell_slices: list[xr.DataArray], + contributing: list[int], + info: _CellInfo, + ) -> np.ndarray: ... + + +def _sweep_line_assemble( + regions: list[xr.DataArray], + bboxes: list[dict[str, tuple[int, int]]], + overlap_fn: OverlapHandler, + tile_dims: tuple[str, ...], +) -> tuple[da.Array, dict[str, tuple[int, int]]]: + """N-D sweep-line decomposition + ``da.block()`` assembly. 
+ + Parameters + ---------- + regions : list[xr.DataArray] + Input data arrays (FOVs or processed tiles). + bboxes : list[dict[str, tuple[int, int]]] + Pixel-space bounding boxes per region, keyed by dim name. + e.g. ``{"y": (0, 256), "x": (0, 512)}`` or + ``{"z": (0, 32), "y": (0, 256), "x": (0, 512)}``. + overlap_fn : OverlapHandler + Handler for cells with 2+ contributors. + tile_dims : tuple[str, ...] + Ordered tiling dimensions, e.g. ``("y", "x")`` or ``("z", "y", "x")``. + + Returns + ------- + tuple[da.Array, dict[str, tuple[int, int]]] + ``(mosaic_dask, global_bounds)`` where global_bounds maps each + tiled dim to ``(min, max)``. + """ + # Global bounds per tiled dimension + global_bounds: dict[str, tuple[int, int]] = {} + for dim in tile_dims: + dim_min = min(b[dim][0] for b in bboxes) + dim_max = max(b[dim][1] for b in bboxes) + global_bounds[dim] = (dim_min, dim_max) + + # Compute unique edges per tiled dimension + edges_per_dim: dict[str, list[int]] = {} + for dim in tile_dims: + edges = sorted({coord for b in bboxes for coord in (b[dim][0], b[dim][1])}) + edges_per_dim[dim] = edges + + # Number of cells per dimension + n_cells_per_dim = {dim: len(edges) - 1 for dim, edges in edges_per_dim.items()} + + # Leading dims: all dims from the first region that are NOT tiled + first = regions[0] + leading_dims = [d for d in first.dims if d not in tile_dims] + leading_shape = tuple(first.sizes[d] for d in leading_dims) + dtype = first.dtype + + # Build N-D block grid + # We iterate over all cell indices (cartesian product of per-dim cell ranges) + dim_cell_ranges = [range(n_cells_per_dim[d]) for d in tile_dims] + + # We'll build a flat dict mapping cell index tuple -> da.Array block, + # then reshape into nested lists for da.block() + blocks: dict[tuple[int, ...], da.Array] = {} + + for cell_idx in product(*dim_cell_ranges): + # Cell bounds for each tiled dim + cell_bounds: dict[str, tuple[int, int]] = {} + cell_spatial_shape: list[int] = [] + for i, dim in 
enumerate(tile_dims): + start = edges_per_dim[dim][cell_idx[i]] + end = edges_per_dim[dim][cell_idx[i] + 1] + cell_bounds[dim] = (start, end) + cell_spatial_shape.append(end - start) + + # Which regions fully cover this cell? + contributing: list[int] = [] + for idx, bbox in enumerate(bboxes): + covers = all( + bbox[dim][0] <= cell_bounds[dim][0] and bbox[dim][1] >= cell_bounds[dim][1] for dim in tile_dims + ) + if covers: + contributing.append(idx) + + cell_shape = leading_shape + tuple(cell_spatial_shape) + + if len(contributing) == 0: + block = da.full(cell_shape, np.nan, dtype=dtype) + + elif len(contributing) == 1: + region = regions[contributing[0]] + local_slices = { + dim: slice( + cell_bounds[dim][0] - bboxes[contributing[0]][dim][0], + cell_bounds[dim][1] - bboxes[contributing[0]][dim][0], + ) + for dim in tile_dims + } + block = region.isel(**local_slices).data + + else: + cell_slices: list[xr.DataArray] = [] + for idx in contributing: + region = regions[idx] + local_slices = { + dim: slice( + cell_bounds[dim][0] - bboxes[idx][dim][0], + cell_bounds[dim][1] - bboxes[idx][dim][0], + ) + for dim in tile_dims + } + cell_slices.append(region.isel(**local_slices)) + + info = _CellInfo(bounds=cell_bounds) + result = dask.delayed(overlap_fn)(cell_slices, contributing, info) + block = da.from_delayed(result, shape=cell_shape, dtype=dtype) + + blocks[cell_idx] = block + + # Reshape flat dict into nested list structure for da.block() + grid_shape = tuple(n_cells_per_dim[d] for d in tile_dims) + mosaic = _build_nested_block(blocks, grid_shape, depth=0) + mosaic = da.block(mosaic) + + return mosaic, global_bounds + + +def _build_nested_block( + blocks: dict[tuple[int, ...], da.Array], + grid_shape: tuple[int, ...], + depth: int, + prefix: tuple[int, ...] 
 = (),
):
    """Recursively build nested list structure for da.block() from flat dict."""
    if depth == len(grid_shape) - 1:
        # Innermost dimension: return a list of blocks
        return [blocks[prefix + (i,)] for i in range(grid_shape[depth])]
    else:
        return [_build_nested_block(blocks, grid_shape, depth + 1, prefix + (i,)) for i in range(grid_shape[depth])]
diff --git a/src/iohub/tile/_tiler.py b/src/iohub/tile/_tiler.py
new file mode 100644
index 00000000..d48c9b5e
--- /dev/null
+++ b/src/iohub/tile/_tiler.py
@@ -0,0 +1,357 @@
"""Tile generation for xr.DataArray with overlap support.

Produces Tile objects that lazily slice an xr.DataArray mosaic.
Inspired by xbatcher (issue #172 Tiler/Batcher decomposition) and
patchly's SamplingMode strategies.
"""

from __future__ import annotations

from enum import Enum
from functools import reduce
from itertools import product
from operator import mul
from typing import Iterator

import networkx as nx
import numpy as np
import xarray as xr

from iohub._experimental import experimental


class SamplingMode(Enum):
    """How to handle the last tile when the array doesn't divide evenly."""

    EDGE = "edge"
    """Align last tile to array edge (may increase overlap for last tile)."""

    SQUEEZE = "squeeze"
    """Redistribute overlap evenly across all tiles (default). From patchly."""

    CROP = "crop"
    """Discard tiles that extend beyond the array border."""


def _gen_slices_1d(
    dim_size: int,
    tile_size: int,
    overlap: int = 0,
    mode: SamplingMode = SamplingMode.SQUEEZE,
) -> list[int]:
    """Generate tile start positions for one dimension.

    Parameters
    ----------
    dim_size : int
        Size of the dimension in pixels.
    tile_size : int
        Size of each tile in pixels.
    overlap : int
        Overlap between adjacent tiles in pixels.
    mode : SamplingMode
        Border handling strategy.

    Returns
    -------
    list[int]
        Start positions for each tile.

    Raises
    ------
    ValueError
        If overlap >= tile_size (stride would be <= 0).
    """
    # A tile larger than the dimension yields a single tile at origin
    # (it extends past the border regardless of mode).
    if tile_size > dim_size:
        return [0]
    if overlap >= tile_size:
        raise ValueError(f"overlap ({overlap}) must be less than tile_size ({tile_size})")

    stride = tile_size - overlap
    positions = list(range(0, dim_size - tile_size + 1, stride))

    if mode == SamplingMode.SQUEEZE:
        last_end = positions[-1] + tile_size
        if last_end < dim_size:
            # Need one more tile; redistribute positions evenly
            n = len(positions) + 1
            max_start = dim_size - tile_size
            if n == 1:
                positions = [0]
            else:
                positions = [round(i * max_start / (n - 1)) for i in range(n)]
        elif last_end > dim_size and len(positions) > 1:
            # NOTE(review): positions come from range(0, dim_size - tile_size + 1),
            # so last_end <= dim_size always holds here — this branch looks
            # unreachable; confirm before removing.
            # Squeeze existing positions so last tile fits exactly
            max_start = dim_size - tile_size
            n = len(positions)
            positions = [round(i * max_start / (n - 1)) for i in range(n)]

    elif mode == SamplingMode.EDGE:
        last_end = positions[-1] + tile_size
        if last_end < dim_size:
            # Append one edge-aligned tile; it overlaps its predecessor more
            # than `overlap` pixels.
            positions.append(dim_size - tile_size)

    elif mode == SamplingMode.CROP:
        pass  # positions already only include full tiles

    return positions


def _ceil_to_multiple(value: int, multiple: int) -> int:
    """Round up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple


class Tile:
    """Metadata for a single tile. Holds slices into the parent xr.DataArray.

    Use ``to_xarray()`` to get the tile data as a labeled xr.DataArray
    (dask-backed, lazy until ``.compute()``).

    Parameters
    ----------
    tile_id : int
        Unique identifier for this tile.
    slices : dict[str, slice]
        Dimension name to slice mapping, e.g.
        ``{"y": slice(0, 256), "x": slice(0, 256)}`` or
        ``{"z": slice(0, 32), "y": slice(0, 256), "x": slice(0, 256)}``.
    data : xr.DataArray
        Back-reference to the mosaic xarray (set by Tiler).
class Tile:
    """Metadata for a single tile. Holds slices into the parent xr.DataArray.

    Use ``to_xarray()`` to get the tile data as a labeled xr.DataArray
    (dask-backed, lazy until ``.compute()``).

    Parameters
    ----------
    tile_id : int
        Unique identifier for this tile.
    slices : dict[str, slice]
        Dimension name to slice mapping, e.g.
        ``{"y": slice(0, 256), "x": slice(0, 256)}`` or
        ``{"z": slice(0, 32), "y": slice(0, 256), "x": slice(0, 256)}``.
    data : xr.DataArray
        Back-reference to the mosaic xarray (set by Tiler).
    """

    __slots__ = ("tile_id", "slices", "_data")

    def __init__(self, tile_id: int, slices: dict[str, slice], data: xr.DataArray):
        self.tile_id = tile_id
        self.slices = slices
        self._data = data

    def to_xarray(self) -> xr.DataArray:
        """Slice the mosaic xarray for this tile.

        Returns a dask-backed DataArray preserving all non-tiled dimensions
        and physical coordinates (subset of the mosaic's global coords).
        """
        return self._data.isel(**self.slices)

    @property
    def tile_dims(self) -> tuple[str, ...]:
        """Dimension names being tiled, e.g. ``("z", "y", "x")``."""
        # Iterating a dict yields its keys in insertion order.
        return tuple(self.slices)

    @property
    def bbox(self) -> np.ndarray:
        """Bounding box as ``[[dim0_start, dim0_stop], ...]``."""
        return np.array([(sl.start, sl.stop) for sl in self.slices.values()])

    @property
    def tile_shape(self) -> tuple[int, ...]:
        """Tile shape in tiled dimensions."""
        extents = []
        for sl in self.slices.values():
            extents.append(sl.stop - sl.start)
        return tuple(extents)

    def __repr__(self) -> str:
        spans = [f"{dim}={sl.start}:{sl.stop}" for dim, sl in self.slices.items()]
        return f"Tile(tile_id={self.tile_id}, " + ", ".join(spans) + ")"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Tile):
            return (self.tile_id, self.slices) == (other.tile_id, other.slices)
        return NotImplemented

    def __hash__(self) -> int:
        # Hash the same (id, (dim, start, stop), ...) tuple as equality uses,
        # so equal tiles hash equal.
        key = tuple((dim, sl.start, sl.stop) for dim, sl in self.slices.items())
        return hash((self.tile_id,) + key)
``{"y": 1024, "x": 1024}`` or + ``{"z": 32, "y": 1024, "x": 1024}``. + overlap : dict[str, int] | None + Overlap between adjacent tiles, e.g. ``{"y": 128, "x": 128}``. + mode : SamplingMode + Border handling strategy. Default: SQUEEZE. + align_to_chunks : bool + If True, snap tile_size up to the nearest multiple of the data's + zarr chunk size in each tiled dimension. Prevents partial chunk reads. + align_to_shards : bool + If True and the zarr array has shards, align to shard boundaries + instead of chunk boundaries. Takes precedence over align_to_chunks. + + Examples + -------- + >>> pos = open_ome_zarr("deskewed.zarr", mode="r") + >>> tiler = Tiler(pos.to_xarray(), tile_size={"y": 1024, "x": 1024}) + >>> len(tiler) + 15 + >>> for tile in tiler: + ... xa = tile.to_xarray() # lazy xr.DataArray + >>> tiler_3d = Tiler( + ... pos.to_xarray(), + ... tile_size={"z": 32, "y": 1024, "x": 1024}, + ... overlap={"z": 4, "y": 128, "x": 128}, + ... ) + """ + + def __init__( + self, + data: xr.DataArray, + tile_size: dict[str, int], + overlap: dict[str, int] | None = None, + mode: SamplingMode = SamplingMode.SQUEEZE, + align_to_chunks: bool = False, + align_to_shards: bool = False, + ): + if "y" not in data.dims or "x" not in data.dims: + raise ValueError(f"DataArray must have 'y' and 'x' dims, got {data.dims}") + if "y" not in tile_size or "x" not in tile_size: + raise ValueError(f"tile_size must specify 'y' and 'x', got {tile_size}") + for dim in tile_size: + if dim not in data.dims: + raise ValueError(f"tile_size key '{dim}' not found in data dims {data.dims}") + + self._data = data + self._original_tile_size = dict(tile_size) + self._overlap = overlap or {} + self._mode = mode + self._align_to_chunks = align_to_chunks + self._align_to_shards = align_to_shards + + # Determine tiled dimensions, ordered as they appear in data.dims + self._tile_dims = tuple(d for d in data.dims if d in tile_size) + + # Chunk/shard alignment: snap tile_size up to nearest multiple + tile_size = 
dict(tile_size) # don't mutate caller's dict + if (align_to_chunks or align_to_shards) and data.chunks is not None: + for dim in self._tile_dims: + dim_idx = list(data.dims).index(dim) + chunk_size = data.chunks[dim_idx][0] + tile_size[dim] = _ceil_to_multiple(tile_size[dim], chunk_size) + + self._tile_size = tile_size + + # Generate tile positions per dimension + self._positions_per_dim: dict[str, list[int]] = {} + for dim in self._tile_dims: + self._positions_per_dim[dim] = _gen_slices_1d( + dim_size=data.sizes[dim], + tile_size=tile_size[dim], + overlap=self._overlap.get(dim, 0), + mode=mode, + ) + + # Build Tiles from cartesian product of all tiled dim positions + self._tiles: list[Tile] = [] + dim_positions = [self._positions_per_dim[d] for d in self._tile_dims] + tile_id = 0 + for combo in product(*dim_positions): + slices = {} + for dim, start in zip(self._tile_dims, combo): + end = min(start + tile_size[dim], data.sizes[dim]) + slices[dim] = slice(start, end) + self._tiles.append(Tile(tile_id=tile_id, slices=slices, data=data)) + tile_id += 1 + + self._graph: nx.Graph | None = None + + @property + def tile_dims(self) -> tuple[str, ...]: + """Dimension names being tiled, e.g. 
``("z", "y", "x")`` or ``("y", "x")``.""" + return self._tile_dims + + def __iter__(self) -> Iterator[Tile]: + yield from self._tiles + + def iter_xarrays(self) -> Iterator[xr.DataArray]: + """Iterate over tiles as lazy xr.DataArrays.""" + for tile in self._tiles: + yield tile.to_xarray() + + def __len__(self) -> int: + return len(self._tiles) + + def __getitem__(self, idx: int | slice) -> Tile | list[Tile]: + return self._tiles[idx] + + def __repr__(self) -> str: + grid_parts = "x".join(str(len(self._positions_per_dim[d])) for d in self._tile_dims) + size_str = f"tile_size={self._tile_size}" + if self._tile_size != self._original_tile_size: + size_str += f" (requested={self._original_tile_size})" + return f"Tiler(tiles={len(self)}, grid={grid_parts}, {size_str}, overlap={self._overlap})" + + @property + def data(self) -> xr.DataArray: + """The underlying mosaic xr.DataArray.""" + return self._data + + @property + def overlap(self) -> dict[str, int]: + """Overlap in pixels per dimension, e.g. ``{"y": 128, "x": 128}``.""" + return self._overlap + + @property + def graph(self) -> nx.Graph: + """Tile neighborhood graph. Nodes=tile_ids, edges=overlapping pairs. + + Built lazily on first access. Uses grid-based construction. + """ + if self._graph is None: + self._graph = self._build_neighborhood_graph() + return self._graph + + @property + def tile_grid_shape(self) -> tuple[int, ...]: + """Number of tiles per tiled dimension.""" + return tuple(len(self._positions_per_dim[d]) for d in self._tile_dims) + + def _build_neighborhood_graph(self) -> nx.Graph: + """Build tile neighborhood graph using N-D grid-based lookup. + + Tiles are on a regular N-D grid (from cartesian product of + positions). Neighbors along each dimension with overlap > 0 + are connected by edges. 
+ """ + G = nx.Graph() + grid_shape = self.tile_grid_shape + n_dims = len(grid_shape) + + for tile in self._tiles: + G.add_node(tile.tile_id) + + # Strides for converting N-D grid index to flat tile_id (row-major) + strides = [] + for i in range(n_dims): + strides.append(reduce(mul, grid_shape[i + 1 :], 1)) + + for tile in self._tiles: + # Recover N-D grid index from flat tile_id + grid_idx = [] + remaining = tile.tile_id + for s in strides: + grid_idx.append(remaining // s) + remaining %= s + + # Connect to next neighbor along each dimension with overlap + for axis, dim in enumerate(self._tile_dims): + if self._overlap.get(dim, 0) > 0 and grid_idx[axis] + 1 < grid_shape[axis]: + neighbor_idx = list(grid_idx) + neighbor_idx[axis] += 1 + neighbor_id = sum(i * s for i, s in zip(neighbor_idx, strides)) + G.add_edge(tile.tile_id, neighbor_id) + + return G diff --git a/src/iohub/tile/tile.py b/src/iohub/tile/tile.py new file mode 100644 index 00000000..10bc556a --- /dev/null +++ b/src/iohub/tile/tile.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +import logging +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Literal, overload + +import numpy as np +import xarray as xr +import zarr + +from iohub._experimental import experimental +from iohub.tile._blend import _blend_tiles +from iohub.tile._blenders import ( + Blender, + get_blender, +) +from iohub.tile._compositors import ( + Compositor, +) +from iohub.tile._resolvers import ( + LayoutResolver, +) +from iohub.tile._tiler import SamplingMode, Tile, Tiler + +if TYPE_CHECKING: + from iohub.ngff.nodes import Position, Well + +logger = logging.getLogger(__name__) + +CacheMode = Literal["persist", "bfs"] + +# Fixed well path in the temp HCS store +_WELL_ROW = "A" +_WELL_COL = "1" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def 
def _to_xarray(data: Position | Well, **kwargs) -> xr.DataArray:
    """Convert Position or Well to xr.DataArray."""
    return data.to_xarray(**kwargs)


def _tile_position_path(store: str | Path, tile_id: int) -> Path:
    """Return path to a tile's standalone FOV zarr."""
    return Path(store) / _WELL_ROW / _WELL_COL / f"tile_{tile_id}"


def _read_store_meta(store: str | Path) -> dict:
    """Read iohub store-level metadata from temp tile store.

    Raises
    ------
    ValueError
        If any required metadata key is missing (store not created by
        :func:`create_tile_store`).
    """
    root = zarr.open_group(str(store), mode="r")
    meta = dict(root.attrs.get("iohub", {}))
    for required in ("tile_size", "n_tiles", "tile_dims"):
        if required not in meta:
            raise ValueError(
                f"Temp tile store at {store} is missing required metadata key '{required}'. "
                f"Was create_tile_store called for this store?"
            )
    return meta


def _pixel_spacing(coords: np.ndarray) -> float:
    """Infer pixel spacing from a physical coordinate array.

    Assumes uniform spacing (uses the first two coordinates); falls back
    to 1.0 when fewer than two coordinates are available.
    """
    if len(coords) > 1:
        return float(coords[1] - coords[0])
    return 1.0


# ---------------------------------------------------------------------------
# Phase 1: create_tile_store
# ---------------------------------------------------------------------------


@experimental
def create_tile_store(
    data: Position,
    tile_size: dict[str, int],
    store: str | Path,
    *,
    overlap: dict[str, int] | None = None,
    tile_batch_size: int = 16,
    mode: SamplingMode = SamplingMode.SQUEEZE,
) -> list[list[int]]:
    """Set up a temp HCS store for tiled processing. Returns batched tile IDs.

    Creates the store directory structure and writes store-level metadata.
    Each tile result will later be written as a standalone OME-Zarr FOV at
    ``store/A/1/tile_{tile_id}/``.

    Parameters
    ----------
    data : Position
        Source OME-Zarr Position to tile.
    tile_size : dict[str, int]
        Tile size per dimension, e.g. ``{"z": 32, "y": 256, "x": 256}``.
    store : str | Path
        Path for the temp tile store. Must not already exist.
    overlap : dict[str, int] | None
        Overlap between tiles, e.g. ``{"z": 16, "y": 32, "x": 32}``.
    tile_batch_size : int
        Number of tiles per batch (one batch = one SLURM job). Default: 16.
    mode : SamplingMode
        Border handling strategy. Default: SQUEEZE.

    Returns
    -------
    list[list[int]]
        Batched tile IDs, e.g. ``[[0,1,...,15], [16,...,31], ...]``.
        Pass each batch to :func:`process_tiles`.

    Raises
    ------
    FileExistsError
        If ``store`` already exists.
    """
    overlap = overlap or {}
    store = Path(store)

    if store.exists():
        raise FileExistsError(f"Temp tile store already exists: {store}")

    xa = data.to_xarray()
    tiler = Tiler(xa, tile_size=tile_size, overlap=overlap, mode=mode)
    n_tiles = len(tiler)

    logger.info(
        "Creating tile store: %s (%d tiles, grid=%s, tile_dims=%s)",
        store,
        n_tiles,
        tiler.tile_grid_shape,
        tiler.tile_dims,
    )

    # Create directory structure and root zarr group
    well_dir = store / _WELL_ROW / _WELL_COL
    well_dir.mkdir(parents=True)

    # Write store-level metadata via zarr's attr API (handles v0.4 .zattrs and v0.5 zarr.json)
    root = zarr.open_group(str(store), mode="w")
    root.attrs["iohub"] = {
        "tile_size": tile_size,
        "overlap": overlap,
        # BUG FIX: persist the sampling mode. The later phases read
        # meta.get("mode", "squeeze"); without this key a non-default mode
        # was silently dropped and tiles were reconstructed on a different
        # grid than the one used here.
        "mode": mode.value,
        "tile_dims": list(tiler.tile_dims),
        "n_tiles": n_tiles,
    }

    # Batch tile IDs: one batch per downstream job.
    all_ids = list(range(n_tiles))
    batches = [all_ids[i : i + tile_batch_size] for i in range(0, n_tiles, tile_batch_size)]

    logger.info(
        "Created %d batches of up to %d tiles each",
        len(batches),
        tile_batch_size,
    )
    return batches


# ---------------------------------------------------------------------------
# Phase 2: process_tiles
# ---------------------------------------------------------------------------
@experimental
def process_tiles(
    data: Position,
    fn: Callable[[xr.DataArray], xr.DataArray | np.ndarray],
    store: str | Path,
    tile_ids: list[int],
) -> None:
    """Apply fn to a batch of tiles and write results to the temp tile store.

    Designed to run as a SLURM job (via submitit). Each tile is written as
    a standalone OME-Zarr Position at ``store/A/1/tile_{tile_id}/``.

    Parameters
    ----------
    data : Position
        Source OME-Zarr Position (same as passed to :func:`create_tile_store`).
    fn : callable
        Function to apply to each tile. Receives an xr.DataArray (5D TCZYX),
        returns an xr.DataArray or np.ndarray of the same shape.
    store : str | Path
        Path to the temp tile store created by :func:`create_tile_store`.
    tile_ids : list[int]
        Tile IDs to process (one batch from :func:`create_tile_store`).

    Raises
    ------
    ValueError
        If ``fn`` returns an array that is not 5D or does not match the
        tile's shape.
    """
    from iohub.ngff import open_ome_zarr
    from iohub.ngff.models import TransformationMeta

    store = Path(store)
    meta = _read_store_meta(store)
    tile_size = meta["tile_size"]
    overlap = meta.get("overlap", {})
    # Sampling mode recorded at store creation; default to squeeze for
    # stores that predate the "mode" key.
    mode_str = meta.get("mode", "squeeze")
    mode = SamplingMode(mode_str) if isinstance(mode_str, str) else SamplingMode.SQUEEZE

    # Rebuild the exact tile grid used by create_tile_store.
    xa = data.to_xarray()
    tiler = Tiler(xa, tile_size=tile_size, overlap=overlap, mode=mode)
    tiles = list(tiler)

    # Get source scale and channel names
    src_transforms = data.metadata.multiscales[0].datasets[0].coordinate_transformations
    channel_names = list(data.channel_names)

    # Infer pixel spacings per tiled dim (1.0 when no physical coordinate).
    spacings: dict[str, float] = {}
    for dim in tiler.tile_dims:
        if dim in xa.coords:
            spacings[dim] = _pixel_spacing(xa.coords[dim].values)
        else:
            spacings[dim] = 1.0

    # Full dim order for building translation
    all_dims = tuple(str(d) for d in xa.dims)

    for tile_id in tile_ids:
        tile = tiles[tile_id]
        tile_xa = tile.to_xarray()

        result = fn(tile_xa)
        if isinstance(result, np.ndarray):
            result_arr = result
        else:
            result_arr = result.values

        # Validate result shape before touching the store.
        if result_arr.ndim != 5:
            raise ValueError(
                f"fn must return a 5D array (TCZYX), got shape {result_arr.shape} for tile {tile_id}"
            )
        if result_arr.shape != tile_xa.shape:
            raise ValueError(
                f"fn returned shape {result_arr.shape} but tile {tile_id} has shape {tile_xa.shape}. "
                f"fn must preserve spatial dimensions."
            )

        # Build translation transform from tile start positions
        translation = []
        for dim in all_dims:
            if dim in tile.slices:
                start = tile.slices[dim].start
                translation.append(float(start) * spacings.get(dim, 1.0))
            else:
                translation.append(0.0)

        # Build transforms: copy scale from source, set translation from tile position
        scale_meta = None
        for tr in src_transforms or []:
            if tr.type == "scale":
                scale_meta = deepcopy(tr)
                break
        if scale_meta is None:
            scale_meta = TransformationMeta(type="scale", scale=[1.0] * len(all_dims))

        translation_meta = TransformationMeta(type="translation", translation=translation)

        # Write tile as standalone OME-Zarr FOV
        tile_path = str(_tile_position_path(store, tile_id))
        tile_meta = {
            "tile_id": tile_id,
            "slices": {d: [s.start, s.stop] for d, s in tile.slices.items()},
        }
        pos = open_ome_zarr(tile_path, layout="fov", mode="w-", channel_names=channel_names)
        pos.create_image(
            "0",
            result_arr,
            transform=[scale_meta, translation_meta],
        )
        # Write iohub metadata after image data — if this is interrupted the tile
        # directory exists but stitch_from_store will raise a clear error on missing "slices".
        pos.zattrs["iohub"] = tile_meta
        # BUG FIX: release the per-tile store handle; previously each handle
        # leaked for the lifetime of the batch job. (open_ome_zarr nodes
        # support close() — the documented usage is the context-manager form.)
        pos.close()

        logger.info("Wrote tile %d to %s", tile_id, tile_path)


# ---------------------------------------------------------------------------
# Phase 3: stitch_from_store
# ---------------------------------------------------------------------------
@experimental
def stitch_from_store(
    store: str | Path,
    output: str | Path,
    source_position: Position,
    *,
    weights: str | Blender = "gaussian",
) -> None:
    """Blend tile results from temp store into a final output OME-Zarr.

    Reads all tile FOVs from the temp store, reconstructs their spatial
    positions, blends overlapping regions, and writes the result to
    ``output``.

    Parameters
    ----------
    store : str | Path
        Temp tile store created by :func:`create_tile_store` and populated
        by :func:`process_tiles`.
    output : str | Path
        Path for the output OME-Zarr store (Position layout).
    source_position : Position
        Original source Position — used to reconstruct Tiler and copy
        OME-Zarr metadata to the output.
    weights : str | Blender
        Blending strategy. Default: ``"gaussian"``.

    Raises
    ------
    FileNotFoundError
        If any expected tile FOV is missing from the store.
    ValueError
        If a tile FOV lacks its ``slices`` metadata (partial write).
    """
    from iohub.ngff import open_ome_zarr

    store = Path(store)
    meta = _read_store_meta(store)
    tile_size = meta["tile_size"]
    overlap = meta.get("overlap", {})
    n_tiles = meta["n_tiles"]
    mode_str = meta.get("mode", "squeeze")
    mode = SamplingMode(mode_str) if isinstance(mode_str, str) else SamplingMode.SQUEEZE

    source_xa = source_position.to_xarray()
    tiler = Tiler(source_xa, tile_size=tile_size, overlap=overlap, mode=mode)

    logger.info("Stitching %d tiles from %s → %s", n_tiles, store, output)

    # Read all tile results + reconstruct Tile specs.
    # NOTE: the per-tile read handles are intentionally not closed here —
    # pos.to_xarray() is consumed lazily when the blend is computed below.
    tile_xarrays: list[xr.DataArray] = []
    tile_specs: list[Tile] = []

    for i in range(n_tiles):
        tile_path = str(_tile_position_path(store, i))
        try:
            pos = open_ome_zarr(tile_path, layout="fov")
        except Exception as e:
            raise FileNotFoundError(
                f"Tile {i} not found in store {store}. Was process_tiles run for tile_id={i}? Path: {tile_path}"
            ) from e

        tile_meta = pos.zattrs.get("iohub", {})
        if "slices" not in tile_meta:
            raise ValueError(
                f"Tile {i} at {tile_path} is missing 'slices' in iohub metadata. "
                f"The tile may have been partially written (e.g. process killed mid-write)."
            )

        slices = {d: slice(v[0], v[1]) for d, v in tile_meta["slices"].items()}
        tile = Tile(tile_id=i, slices=slices, data=source_xa)

        tile_xarrays.append(pos.to_xarray())
        tile_specs.append(tile)

    # Blend
    blender = get_blender(weights)
    blended = _blend_tiles(tile_xarrays, tile_specs, blender, tiler)

    # Write output OME-Zarr
    logger.info("Writing output to %s", output)
    dst = open_ome_zarr(
        str(output),
        layout="fov",
        mode="w-",
        channel_names=list(source_position.channel_names),
    )

    shape = tuple(blended.sizes[d] for d in blended.dims)
    dtype = source_position.data.dtype

    # Copy coordinate transforms from source
    src_transforms = deepcopy(
        source_position.metadata.multiscales[0].datasets[0].coordinate_transformations
    )
    dst.create_zeros("0", shape=shape, dtype=dtype)
    dst.metadata.multiscales[0].datasets[0].coordinate_transformations = src_transforms
    dst.dump_meta()

    output_arr = dst["0"]

    # Compute the blended mosaic in memory and write it in one assignment.
    # (The previous comment claimed "chunk-by-chunk"; this materializes the
    # full array — for very large outputs a chunked write would be needed.)
    blended_computed = blended.values.astype(dtype)
    output_arr[...] = blended_computed
    # BUG FIX: release the output store handle once the data is written.
    dst.close()

    logger.info("Stitch complete: %s (shape=%s)", output, shape)


# ---------------------------------------------------------------------------
# apply_func_tiled (in-memory, unchanged)
# ---------------------------------------------------------------------------


@overload
def apply_func_tiled(
    data: Position,
    fn: Callable[[xr.DataArray], xr.DataArray | np.ndarray],
    tile_size: dict[str, int],
    *,
    overlap: dict[str, int] | None = ...,
    weights: str | Blender = ...,
    mode: SamplingMode = ...,
    align_to_chunks: bool = ...,
    cache: CacheMode | None = ...,
) -> xr.DataArray: ...
+ + +@overload +def apply_func_tiled( + data: Well, + fn: Callable[[xr.DataArray], xr.DataArray | np.ndarray], + tile_size: dict[str, int], + *, + overlap: dict[str, int] | None = ..., + weights: str | Blender = ..., + mode: SamplingMode = ..., + align_to_chunks: bool = ..., + cache: CacheMode | None = ..., + layout_resolver: LayoutResolver | None = ..., + compositor: str | Compositor = ..., +) -> xr.DataArray: ... + + +@experimental +def apply_func_tiled( + data: Position | Well, + fn: Callable[[xr.DataArray], xr.DataArray | np.ndarray], + tile_size: dict[str, int], + *, + overlap: dict[str, int] | None = None, + weights: str | Blender = "gaussian", + mode: SamplingMode = SamplingMode.SQUEEZE, + align_to_chunks: bool = False, + cache: Literal["persist", "bfs"] | None = None, + # Well-specific (forwarded to Well.to_xarray) + layout_resolver: LayoutResolver | None = None, + compositor: str | Compositor = "mean", +) -> xr.DataArray: + """Tile a volume, apply a function, and blend back in memory. + + Returns a lazy dask-backed ``xr.DataArray`` without writing to zarr. + For large volumes use :func:`create_tile_store` / :func:`process_tiles` / + :func:`stitch_from_store` instead. + + Parameters + ---------- + data : Position | Well + Input volume. + fn : callable + Function applied to each tile. + tile_size : dict[str, int] + Tile size, e.g. ``{"y": 1024, "x": 1024}``. + overlap : dict[str, int] | None + Overlap between tiles. + weights : str | Blender + Blending strategy. Default: ``"gaussian"``. + mode : SamplingMode + Border handling. Default: SQUEEZE. + align_to_chunks : bool + Snap tile boundaries to chunk multiples. + cache : ``"persist"`` | ``"bfs"`` | None + Overlap caching strategy. + + Returns + ------- + xr.DataArray + Lazy dask-backed result. 
+ """ + from iohub.tile._cache import _bfs_tile_order, _persist_overlaps + + well_kwargs: dict = {} + if layout_resolver is not None: + well_kwargs["layout_resolver"] = layout_resolver + if compositor != "mean": + well_kwargs["compositor"] = compositor + + xa = _to_xarray(data, **well_kwargs) + + tiler = Tiler( + xa, + tile_size=tile_size, + overlap=overlap, + mode=mode, + align_to_chunks=align_to_chunks, + ) + + if cache == "persist": + _persist_overlaps(tiler) + + tile_specs = list(tiler) + if cache == "bfs": + tile_order = _bfs_tile_order(tiler) + else: + tile_order = list(range(len(tile_specs))) + + processed: list[xr.DataArray | None] = [None] * len(tile_specs) + for idx in tile_order: + tile_xa = tile_specs[idx].to_xarray() + result = fn(tile_xa) + if isinstance(result, np.ndarray): + result = xr.DataArray( + result, + dims=tile_xa.dims, + coords=tile_xa.coords, + ) + processed[idx] = result + + blender = get_blender(weights) + return _blend_tiles(processed, tile_specs, blender, tiler) diff --git a/tests/conftest.py b/tests/conftest.py index d44d7e60..8bffa3a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,49 +60,44 @@ def subdirs(parent: Path, name: str) -> list[Path]: return [d for d in (parent / name).iterdir() if d.is_dir()] -test_datasets = download_data() - - -mm2gamma_ome_tiffs = subdirs(test_datasets, "MM20_ome-tiffs") - - -mm2gamma_ome_tiffs_hcs = [p for p in mm2gamma_ome_tiffs if "4p" in p.name] - - -# This is a dataset with 11 timepoints -# The MDA definition at start of the experiment specifies 20 timepoints -mm2gamma_ome_tiffs_incomplete = test_datasets / "MM20_ometiff_incomplete" / "mm2.0-20201209_20t_5z_3c_512k_incomplete_1" - - -mm2gamma_singlepage_tiffs = subdirs(test_datasets, "MM20_singlepage-tiffs") - - -# This is a dataset with 11 timepoints -# The MDA definition at start of the experiment specifies 20 timepoints -mm2gamma_singlepage_tiffs_incomplete = ( - test_datasets / "MM20_singlepage_incomplete" / 
"mm2.0-20201209_20t_5z_3c_512k_incomplete_1 2" -) - - -mm1422_ome_tiffs = subdirs(test_datasets, "MM1422_ome-tiffs") - - -mm1422_singlepage_tiffs = subdirs(test_datasets, "MM1422_singlepage-tiffs") - - -mm2gamma_zarr_v01 = test_datasets / "MM20_zarr" / "mm2.0-20201209_4p_2t_5z_1c_512k_1.zarr" - - -hcs_ref = test_datasets / "20200812-CardiomyocyteDifferentiation14-Cycle1.zarr" - - -ndtiff_v2_datasets = subdirs(test_datasets, "MM20_pycromanager") - - -ndtiff_v2_ptcz = test_datasets / "MM20_pycromanager" / "mm2.0-20210713_pm0.13.2_2p_3t_2c_7z_1" - - -ndtiff_v3_labeled_positions = test_datasets / "ndtiff_v3_labeled_positions" +test_datasets = None +mm2gamma_ome_tiffs = [] +mm2gamma_ome_tiffs_hcs = [] +mm2gamma_ome_tiffs_incomplete = None +mm2gamma_singlepage_tiffs = [] +mm2gamma_singlepage_tiffs_incomplete = None +mm1422_ome_tiffs = [] +mm1422_singlepage_tiffs = [] +mm2gamma_zarr_v01 = None +hcs_ref = None +ndtiff_v2_datasets = [] +ndtiff_v2_ptcz = None +ndtiff_v3_labeled_positions = None + +try: + test_datasets = download_data() + mm2gamma_ome_tiffs = subdirs(test_datasets, "MM20_ome-tiffs") + mm2gamma_ome_tiffs_hcs = [p for p in mm2gamma_ome_tiffs if "4p" in p.name] + # This is a dataset with 11 timepoints + # The MDA definition at start of the experiment specifies 20 timepoints + mm2gamma_ome_tiffs_incomplete = ( + test_datasets / "MM20_ometiff_incomplete" / "mm2.0-20201209_20t_5z_3c_512k_incomplete_1" + ) + mm2gamma_singlepage_tiffs = subdirs(test_datasets, "MM20_singlepage-tiffs") + # This is a dataset with 11 timepoints + # The MDA definition at start of the experiment specifies 20 timepoints + mm2gamma_singlepage_tiffs_incomplete = ( + test_datasets / "MM20_singlepage_incomplete" / "mm2.0-20201209_20t_5z_3c_512k_incomplete_1 2" + ) + mm1422_ome_tiffs = subdirs(test_datasets, "MM1422_ome-tiffs") + mm1422_singlepage_tiffs = subdirs(test_datasets, "MM1422_singlepage-tiffs") + mm2gamma_zarr_v01 = test_datasets / "MM20_zarr" / "mm2.0-20201209_4p_2t_5z_1c_512k_1.zarr" 
+ hcs_ref = test_datasets / "20200812-CardiomyocyteDifferentiation14-Cycle1.zarr" + ndtiff_v2_datasets = subdirs(test_datasets, "MM20_pycromanager") + ndtiff_v2_ptcz = test_datasets / "MM20_pycromanager" / "mm2.0-20210713_pm0.13.2_2p_3t_2c_7z_1" + ndtiff_v3_labeled_positions = test_datasets / "ndtiff_v3_labeled_positions" +except Exception: + pass @pytest.fixture diff --git a/tests/tile/__init__.py b/tests/tile/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tile/conftest.py b/tests/tile/conftest.py new file mode 100644 index 00000000..a71007e3 --- /dev/null +++ b/tests/tile/conftest.py @@ -0,0 +1,106 @@ +"""Fixtures and hypothesis strategies for iohub.tile tests.""" + +import warnings + +import hypothesis.strategies as st +import numpy as np +import pytest +import xarray as xr +from hypothesis import HealthCheck, settings + +from iohub._experimental import ExperimentalWarning +from iohub.ngff import open_ome_zarr + +# Default hypothesis settings for tile tests +settings.register_profile( + "tile", + max_examples=15, + deadline=5000, + suppress_health_check=[HealthCheck.too_slow, HealthCheck.function_scoped_fixture], +) +settings.load_profile("tile") + + +@pytest.fixture(autouse=True) +def suppress_experimental_warnings(): + """Suppress ExperimentalWarning in all tile tests.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ExperimentalWarning) + yield + + +@pytest.fixture +def synthetic_5d(): + """Small 5D xr.DataArray (1,1,4,64,128) with random float32 data and physical coords.""" + rng = np.random.default_rng(42) + data = rng.random((1, 1, 4, 64, 128), dtype=np.float32) + return xr.DataArray( + data, + dims=("t", "c", "z", "y", "x"), + coords={ + "y": np.arange(64, dtype=np.float64) * 0.325, + "x": np.arange(128, dtype=np.float64) * 0.325, + }, + ) + + +@pytest.fixture +def synthetic_5d_large_z(): + """5D xr.DataArray (1,1,16,64,128) with enough Z for meaningful Z-tiling.""" + rng = np.random.default_rng(99) + 
data = rng.random((1, 1, 16, 64, 128), dtype=np.float32) + return xr.DataArray( + data, + dims=("t", "c", "z", "y", "x"), + coords={ + "z": np.arange(16, dtype=np.float64) * 0.5, + "y": np.arange(64, dtype=np.float64) * 0.325, + "x": np.arange(128, dtype=np.float64) * 0.325, + }, + ) + + +@pytest.fixture +def synthetic_position(tmp_path, synthetic_5d): + """Position node backed by a real OME-Zarr store (1,1,4,64,128) float32.""" + path = tmp_path / "synthetic.zarr" + pos = open_ome_zarr(str(path), layout="fov", mode="w-", channel_names=["ch0"]) + pos.create_image("0", synthetic_5d.values, chunks=(1, 1, 4, 64, 128)) + return pos + + +@pytest.fixture +def synthetic_position_large_z(tmp_path, synthetic_5d_large_z): + """Position node (1,1,16,64,128) float32 with Z coords for ZYX tiling.""" + path = tmp_path / "synthetic_z.zarr" + pos = open_ome_zarr(str(path), layout="fov", mode="w-", channel_names=["ch0"]) + pos.create_image("0", synthetic_5d_large_z.values, chunks=(1, 1, 16, 64, 128)) + return pos + + +@st.composite +def tile_params(draw, y_size=64, x_size=128): + """Draw a valid (tile_size, overlap) pair for a given YX shape.""" + tile_y = draw(st.integers(8, y_size)) + tile_x = draw(st.integers(8, x_size)) + overlap_y = draw(st.integers(0, tile_y - 1)) + overlap_x = draw(st.integers(0, tile_x - 1)) + return ( + {"y": tile_y, "x": tile_x}, + {"y": overlap_y, "x": overlap_x}, + ) + + +@st.composite +def tile_params_zyx(draw, z_size=16, y_size=64, x_size=128): + """Draw a valid (tile_size, overlap) pair for ZYX tiling.""" + tile_z = draw(st.integers(4, z_size)) + tile_y = draw(st.integers(8, y_size)) + tile_x = draw(st.integers(8, x_size)) + overlap_z = draw(st.integers(0, tile_z - 1)) + overlap_y = draw(st.integers(0, tile_y - 1)) + overlap_x = draw(st.integers(0, tile_x - 1)) + return ( + {"z": tile_z, "y": tile_y, "x": tile_x}, + {"z": overlap_z, "y": overlap_y, "x": overlap_x}, + ) diff --git a/tests/tile/test_assembler.py b/tests/tile/test_assembler.py new file 
mode 100644
index 00000000..fd3e2aee
--- /dev/null
+++ b/tests/tile/test_assembler.py
"""Tests for Assembler tile reassembly."""

import numpy as np
import pytest

from iohub.tile import Tiler
from iohub.tile._assembler import Assembler


def test_roundtrip_preserves_data(synthetic_position, tmp_path):
    """Identity round-trip (tiler -> assembler) preserves data and coords."""
    # Identity "processing": append every overlapping tile unchanged and
    # check that uniform blending reconstructs the original volume.
    original = synthetic_position.data[:]
    xa = synthetic_position.to_xarray()
    tiler = Tiler(xa, tile_size={"y": 32, "x": 64}, overlap={"y": 8, "x": 16})
    asm = Assembler(
        tiler,
        output=str(tmp_path / "out.zarr"),
        source_position=synthetic_position,
        weights="uniform",
    )
    for tile in tiler:
        asm.append(tile, tile.to_xarray())
    result = asm.get_output()

    assert result.shape == original.shape
    # atol absorbs float32 accumulation error in overlap regions.
    np.testing.assert_allclose(result.values, original, atol=1e-5)


def test_roundtrip_gaussian(synthetic_position, tmp_path):
    """Gaussian blending round-trip preserves identity."""
    # Same identity round-trip as above, but with gaussian overlap weights.
    original = synthetic_position.data[:]
    xa = synthetic_position.to_xarray()
    tiler = Tiler(xa, tile_size={"y": 32, "x": 64}, overlap={"y": 8, "x": 16})
    asm = Assembler(
        tiler,
        output=str(tmp_path / "gauss.zarr"),
        source_position=synthetic_position,
        weights="gaussian",
    )
    for tile in tiler:
        asm.append(tile, tile.to_xarray())
    result = asm.get_output()

    assert result.shape == original.shape
    np.testing.assert_allclose(result.values, original, atol=1e-5)


def test_append_after_finalize_raises(synthetic_position, tmp_path):
    """Appending after get_output() raises RuntimeError."""
    xa = synthetic_position.to_xarray()
    # Single full-size tile: no overlap handling needed for this check.
    tiler = Tiler(xa, tile_size={"y": 64, "x": 128})
    asm = Assembler(
        tiler,
        output=str(tmp_path / "out.zarr"),
        source_position=synthetic_position,
        weights="uniform",
    )
    for tile in tiler:
        asm.append(tile, tile.to_xarray())
    asm.get_output()
    with pytest.raises(RuntimeError, match="already finalized"):
        asm.append(tiler[0], tiler[0].to_xarray())


def test_parallel_safety(synthetic_position, tmp_path):
    """validate_parallel_safety detects chunk conflicts from overlapping tiles."""
    xa = synthetic_position.to_xarray()

    # Tiles exactly matching output chunks: safe for parallel writes.
    tiler_safe = Tiler(xa, tile_size={"y": 32, "x": 64})
    asm_safe = Assembler(
        tiler_safe,
        output=str(tmp_path / "a.zarr"),
        source_position=synthetic_position,
        chunks={"y": 32, "x": 64},
    )
    assert asm_safe.validate_parallel_safety() is True

    # Overlapping tiles straddle chunk boundaries: unsafe.
    tiler_unsafe = Tiler(xa, tile_size={"y": 32, "x": 64}, overlap={"y": 8, "x": 16})
    asm_unsafe = Assembler(
        tiler_unsafe,
        output=str(tmp_path / "b.zarr"),
        source_position=synthetic_position,
        chunks={"y": 32, "x": 64},
    )
    assert asm_unsafe.validate_parallel_safety() is False


def test_numpy_input_accepted(synthetic_position, tmp_path):
    """append() accepts raw np.ndarray in addition to xr.DataArray."""
    original = synthetic_position.data[:]
    xa = synthetic_position.to_xarray()
    tiler = Tiler(xa, tile_size={"y": 64, "x": 128})
    asm = Assembler(
        tiler,
        output=str(tmp_path / "out.zarr"),
        source_position=synthetic_position,
        weights="uniform",
    )
    for tile in tiler:
        # Pass .values (plain ndarray) instead of the DataArray.
        asm.append(tile, tile.to_xarray().values)
    result = asm.get_output()
    np.testing.assert_allclose(result.values, original, atol=1e-6)
diff --git a/tests/tile/test_blend.py b/tests/tile/test_blend.py
new file mode 100644
index 00000000..58269ec2
--- /dev/null
+++ b/tests/tile/test_blend.py
"""Tests for _blend_tiles and apply_func_tiled."""

import dask.array
import numpy as np

from iohub.tile import Tiler, apply_func_tiled, get_blender
from iohub.tile._blend import _blend_tiles

# ---- _blend_tiles unit tests ----


def test_blend_tiles_single_tile(synthetic_5d):
    """Single tile returns the tile itself."""
    tiler = Tiler(synthetic_5d, tile_size={"y": 64, "x": 128})
    specs = list(tiler)
    tiles = [s.to_xarray() for s in specs]
    blender = get_blender("uniform")
    result = _blend_tiles(tiles,
specs, blender, tiler) + # Single tile — should be the same object + assert result is tiles[0] + + +def test_blend_tiles_no_overlap(synthetic_5d): + """Non-overlapping tiles: each cell has exactly 1 contributor.""" + tiler = Tiler(synthetic_5d, tile_size={"y": 32, "x": 64}, overlap={"y": 0, "x": 0}) + specs = list(tiler) + tiles = [s.to_xarray() for s in specs] + blender = get_blender("uniform") + result = _blend_tiles(tiles, specs, blender, tiler) + assert result.shape == synthetic_5d.shape + np.testing.assert_allclose(result.values, synthetic_5d.values, atol=1e-6) + + +def test_blend_tiles_uniform_identity(synthetic_5d): + """Uniform weights + identity fn = original data.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + specs = list(tiler) + tiles = [s.to_xarray() for s in specs] + blender = get_blender("uniform") + result = _blend_tiles(tiles, specs, blender, tiler) + assert result.shape == synthetic_5d.shape + np.testing.assert_allclose(result.values, synthetic_5d.values, atol=1e-5) + + +def test_blend_tiles_gaussian_identity(synthetic_5d): + """Gaussian weights + identity fn = original data.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + specs = list(tiler) + tiles = [s.to_xarray() for s in specs] + blender = get_blender("gaussian") + result = _blend_tiles(tiles, specs, blender, tiler) + assert result.shape == synthetic_5d.shape + np.testing.assert_allclose(result.values, synthetic_5d.values, atol=1e-5) + + +def test_blend_tiles_is_lazy(synthetic_5d): + """Result is dask-backed and overlap callback hasn't run yet.""" + call_count = 0 + + class CountingBlender: + """Blender that counts weight-computation calls.""" + + def weights(self, tile_shape, overlap, metadata=None): + nonlocal call_count + call_count += 1 + return np.ones(tile_shape, dtype=np.float64) + + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + 
) + specs = list(tiler) + tiles = [s.to_xarray() for s in specs] + result = _blend_tiles(tiles, specs, CountingBlender(), tiler) + + # Graph built but no overlap callbacks have run yet + assert isinstance(result.data, dask.array.Array) + assert call_count == 0, ( + f"Blender.weights() called {call_count} times during graph construction — overlap regions should stay lazy" + ) + + # Now trigger computation + result.compute() + assert call_count > 0, "Blender.weights() should run during compute()" + + +def test_blend_tiles_length_mismatch(synthetic_5d): + """Mismatched tiles/specs raises ValueError.""" + tiler = Tiler(synthetic_5d, tile_size={"y": 32, "x": 64}) + specs = list(tiler) + tiles = [specs[0].to_xarray()] # wrong length + blender = get_blender("uniform") + try: + _blend_tiles(tiles, specs, blender, tiler) + assert False, "Expected ValueError" + except ValueError: + pass + + +# ---- apply_func_tiled tests ---- + + +def test_apply_func_tiled_identity_roundtrip(synthetic_position): + """apply_func_tiled with identity fn preserves data.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + assert isinstance(result.data, dask.array.Array) + assert result.shape == original.shape + np.testing.assert_allclose(result.values, original, atol=1e-5) + + +def test_apply_func_tiled_scaling(synthetic_position): + """apply_func_tiled correctly applies a scaling function.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t * 2, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + np.testing.assert_allclose(result.values, original * 2, atol=1e-5) + + +def test_apply_func_tiled_numpy_return(synthetic_position): + """apply_func_tiled handles fn returning np.ndarray.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t.values * 
3, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + np.testing.assert_allclose(result.values, original * 3, atol=1e-5) + + +def test_apply_func_tiled_no_overlap(synthetic_position): + """apply_func_tiled works without overlap.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + ) + assert result.shape == original.shape + np.testing.assert_allclose(result.values, original, atol=1e-6) + + +def test_apply_func_tiled_distance_blender(synthetic_position): + """apply_func_tiled with distance blender + identity = original.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + weights="distance", + ) + np.testing.assert_allclose(result.values, original, atol=1e-5) + + +# --------------------------------------------------------------------------- +# ZYX tiling tests +# --------------------------------------------------------------------------- + + +def test_zyx_blend_tiles_uniform_identity(synthetic_5d_large_z): + """ZYX uniform blending with identity preserves data.""" + tiler = Tiler( + synthetic_5d_large_z, + tile_size={"z": 8, "y": 32, "x": 64}, + overlap={"z": 2, "y": 8, "x": 16}, + ) + specs = list(tiler) + tiles = [s.to_xarray() for s in specs] + blender = get_blender("uniform") + result = _blend_tiles(tiles, specs, blender, tiler) + assert result.shape == synthetic_5d_large_z.shape + np.testing.assert_allclose(result.values, synthetic_5d_large_z.values, atol=1e-5) + + +def test_zyx_apply_func_tiled_identity(synthetic_position_large_z): + """ZYX apply_func_tiled with identity fn preserves data.""" + original = synthetic_position_large_z.data[:] + result = apply_func_tiled( + synthetic_position_large_z, + fn=lambda t: t, + tile_size={"z": 8, "y": 32, "x": 64}, + overlap={"z": 2, "y": 8, "x": 16}, + ) + assert isinstance(result.data, 
dask.array.Array) + assert result.shape == original.shape + np.testing.assert_allclose(result.values, original, atol=1e-5) + + +def test_zyx_apply_func_tiled_scaling(synthetic_position_large_z): + """ZYX apply_func_tiled correctly applies a scaling function.""" + original = synthetic_position_large_z.data[:] + result = apply_func_tiled( + synthetic_position_large_z, + fn=lambda t: t * 3, + tile_size={"z": 4, "y": 32, "x": 64}, + overlap={"z": 1, "y": 8, "x": 16}, + ) + np.testing.assert_allclose(result.values, original * 3, atol=1e-5) diff --git a/tests/tile/test_blenders.py b/tests/tile/test_blenders.py new file mode 100644 index 00000000..0a872e9c --- /dev/null +++ b/tests/tile/test_blenders.py @@ -0,0 +1,83 @@ +"""Tests for blender protocol and built-in implementations.""" + +import numpy as np +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from iohub.tile import GaussianBlender, UniformBlender, get_blender + + +@given( + h=st.integers(4, 128), + w=st.integers(4, 256), + blender=st.sampled_from([UniformBlender(), GaussianBlender()]), +) +def test_weight_shape_matches_input(h, w, blender): + """Blender output shape must equal the requested tile shape.""" + w_arr = blender.weights((h, w), {}) + assert w_arr.shape == (h, w) + assert w_arr.dtype == np.float64 + assert (w_arr > 0).all() + + +def test_gaussian_center_weighted(): + """Gaussian kernel has its maximum at the center.""" + g = GaussianBlender() + w = g.weights((32, 64), {}) + assert w[16, 32] >= w[0, 0] + assert w[16, 32] >= w[-1, -1] + + +def test_resolve_by_name_and_unknown(): + """get_blender resolves by name, passes through instances, rejects unknowns.""" + assert isinstance(get_blender("uniform"), UniformBlender) + assert isinstance(get_blender("gaussian"), GaussianBlender) + + b = UniformBlender() + assert get_blender(b) is b + + with pytest.raises(ValueError, match="Unknown blender"): + get_blender("nonexistent") + + +# 
--------------------------------------------------------------------------- +# 3D weight kernel tests +# --------------------------------------------------------------------------- + + +@given( + d=st.integers(4, 32), + h=st.integers(4, 64), + w=st.integers(4, 128), + blender=st.sampled_from([UniformBlender(), GaussianBlender()]), +) +def test_3d_weight_shape(d, h, w, blender): + """Blender output shape matches 3D tile shape.""" + w_arr = blender.weights((d, h, w), {}) + assert w_arr.shape == (d, h, w) + assert w_arr.dtype == np.float64 + assert (w_arr > 0).all() + + +def test_gaussian_3d_center_weighted(): + """3D Gaussian kernel has its maximum at the center.""" + g = GaussianBlender() + w = g.weights((8, 32, 64), {}) + assert w[4, 16, 32] >= w[0, 0, 0] + assert w[4, 16, 32] >= w[-1, -1, -1] + + +def test_3d_gaussian_separability(): + """3D Gaussian is the outer product of three 1D gaussians.""" + g = GaussianBlender() + # Use odd sizes so center is unambiguous + w3d = g.weights((9, 17, 33), {}) + # Each axis slice through center should be a 1D gaussian + center_z = w3d[:, 8, 16] + center_y = w3d[4, :, 16] + center_x = w3d[4, 8, :] + # Max should be at center for each + assert np.argmax(center_z) == 4 + assert np.argmax(center_y) == 8 + assert np.argmax(center_x) == 16 diff --git a/tests/tile/test_cache.py b/tests/tile/test_cache.py new file mode 100644 index 00000000..a21a1da1 --- /dev/null +++ b/tests/tile/test_cache.py @@ -0,0 +1,145 @@ +"""Tests for graph-informed overlap caching.""" + +import numpy as np + +from iohub.tile import Tiler, apply_func_tiled +from iohub.tile._cache import ( + _bfs_tile_order, + _estimate_overlap_bytes, + _overlap_regions, +) + + +def test_overlap_regions(synthetic_5d): + """Overlap regions match expected intersections.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + regions = _overlap_regions(tiler) + + assert len(regions) > 0 + + # Every region should have positive extent in 
all dims + for region in regions: + for dim, sl in region.items(): + assert sl.stop > sl.start + + # Regions should be within data bounds + for region in regions: + for dim, sl in region.items(): + assert sl.start >= 0 + assert sl.stop <= synthetic_5d.sizes[dim] + + +def test_overlap_regions_no_overlap(synthetic_5d): + """No overlap → no regions.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 0, "x": 0}, + ) + regions = _overlap_regions(tiler) + assert len(regions) == 0 + + +def test_overlap_regions_deduplication(synthetic_5d): + """Regions with identical pixel ranges are deduplicated.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + regions = _overlap_regions(tiler) + keys = [tuple((d, s.start, s.stop) for d, s in r.items()) for r in regions] + assert len(keys) == len(set(keys)) + + +def test_bfs_order(synthetic_5d): + """BFS visits every tile and adjacent tiles are near each other.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + order = _bfs_tile_order(tiler) + + # Every tile visited exactly once + assert sorted(order) == list(range(len(tiler))) + + # Adjacent tiles (graph neighbors) should be closer in BFS order + # than in a random permutation. Just check they're all present. 
+ assert len(order) == len(tiler) + + +def test_bfs_order_no_overlap(synthetic_5d): + """BFS on disconnected graph still returns all tiles.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 0, "x": 0}, + ) + order = _bfs_tile_order(tiler) + assert sorted(order) == list(range(len(tiler))) + + +def test_estimate_overlap_bytes(synthetic_5d): + """Byte estimate is positive when overlap > 0.""" + tiler = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + ) + nbytes = _estimate_overlap_bytes(tiler) + assert nbytes > 0 + + # With no overlap, should be 0 + tiler_no = Tiler( + synthetic_5d, + tile_size={"y": 32, "x": 64}, + overlap={"y": 0, "x": 0}, + ) + assert _estimate_overlap_bytes(tiler_no) == 0 + + +def test_apply_func_tiled_persist_roundtrip(synthetic_position): + """apply_func_tiled with cache='persist' produces correct results.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + cache="persist", + ) + assert result.shape == original.shape + np.testing.assert_allclose(result.values, original, atol=1e-5) + + +def test_apply_func_tiled_bfs_roundtrip(synthetic_position): + """apply_func_tiled with cache='bfs' produces correct results.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + cache="bfs", + ) + assert result.shape == original.shape + np.testing.assert_allclose(result.values, original, atol=1e-5) + + +def test_apply_func_tiled_no_cache_default(synthetic_position): + """cache=None (default) works identically to before.""" + original = synthetic_position.data[:] + result = apply_func_tiled( + synthetic_position, + fn=lambda t: t, + tile_size={"y": 32, "x": 64}, + overlap={"y": 8, "x": 16}, + cache=None, + ) + assert result.shape == original.shape + 
np.testing.assert_allclose(result.values, original, atol=1e-5) diff --git a/tests/tile/test_composite.py b/tests/tile/test_composite.py new file mode 100644 index 00000000..1855dd95 --- /dev/null +++ b/tests/tile/test_composite.py @@ -0,0 +1,106 @@ +"""Tests for _composite_fovs sweep-line FOV compositing.""" + +import dask.array +import numpy as np +import xarray as xr + +from iohub.tile._composite import _composite_fovs +from iohub.tile._compositors import MaxCompositor, MeanCompositor + + +def _make_fov(data, y_offset=0.0, x_offset=0.0, pixel_size=1.0): + """Create a 5D FOV xr.DataArray with physical coords. + + Uses pixel_size=1.0 by default to avoid float alignment issues + in xr.concat (the compositors stack slices by coord values). + """ + t, c, z, h, w = data.shape + return xr.DataArray( + data, + dims=("t", "c", "z", "y", "x"), + coords={ + "t": np.arange(t), + "c": np.arange(c), + "z": np.arange(z, dtype=np.float64) * pixel_size, + "y": np.arange(h, dtype=np.float64) * pixel_size + y_offset, + "x": np.arange(w, dtype=np.float64) * pixel_size + x_offset, + }, + ) + + +def test_single_fov_passthrough(): + """Single FOV returns the same object.""" + data = np.ones((1, 1, 1, 8, 8), dtype=np.float32) + fov = _make_fov(data) + result = _composite_fovs([fov], MeanCompositor()) + assert result is fov + + +def test_two_overlapping_fovs_mean(): + """Two FOVs overlapping in X, MeanCompositor averages the overlap.""" + # FOV 0: x=[0..7], FOV 1: x=[4..11] → overlap at x=[4..7] + data_a = np.full((1, 1, 1, 8, 8), 2.0, dtype=np.float32) + data_b = np.full((1, 1, 1, 8, 8), 6.0, dtype=np.float32) + fov_a = _make_fov(data_a, x_offset=0.0) + fov_b = _make_fov(data_b, x_offset=4.0) + + result = _composite_fovs([fov_a, fov_b], MeanCompositor()) + + assert result.shape == (1, 1, 1, 8, 12) # 8 + 8 - 4 overlap = 12 wide + assert isinstance(result.data, dask.array.Array) + + vals = result.values + # Left region (FOV A only): columns 0..3 + np.testing.assert_allclose(vals[..., :4], 
2.0) + # Overlap region: columns 4..7 → mean of 2 and 6 = 4 + np.testing.assert_allclose(vals[..., 4:8], 4.0) + # Right region (FOV B only): columns 8..11 + np.testing.assert_allclose(vals[..., 8:], 6.0) + + +def test_two_overlapping_fovs_max(): + """Two FOVs overlapping in X, MaxCompositor takes the max.""" + data_a = np.full((1, 1, 1, 8, 8), 2.0, dtype=np.float32) + data_b = np.full((1, 1, 1, 8, 8), 6.0, dtype=np.float32) + fov_a = _make_fov(data_a, x_offset=0.0) + fov_b = _make_fov(data_b, x_offset=4.0) + + result = _composite_fovs([fov_a, fov_b], MaxCompositor()) + + vals = result.values + np.testing.assert_allclose(vals[..., :4], 2.0) + np.testing.assert_allclose(vals[..., 4:8], 6.0) # max(2, 6) = 6 + np.testing.assert_allclose(vals[..., 8:], 6.0) + + +def test_non_overlapping_fovs_gap(): + """Two FOVs with a gap produce NaN in the gap region.""" + data_a = np.full((1, 1, 1, 4, 4), 1.0, dtype=np.float32) + data_b = np.full((1, 1, 1, 4, 4), 2.0, dtype=np.float32) + # FOV A at x=0..3, FOV B at x=8..11 → gap at x=4..7 + fov_a = _make_fov(data_a, x_offset=0.0) + fov_b = _make_fov(data_b, x_offset=8.0) + + result = _composite_fovs([fov_a, fov_b], MeanCompositor()) + + assert result.shape == (1, 1, 1, 4, 12) + vals = result.values + np.testing.assert_allclose(vals[..., :4], 1.0) + assert np.all(np.isnan(vals[..., 4:8])) + np.testing.assert_allclose(vals[..., 8:], 2.0) + + +def test_overlapping_fovs_y_direction(): + """Two FOVs overlapping in Y.""" + data_a = np.full((1, 1, 1, 8, 4), 10.0, dtype=np.float32) + data_b = np.full((1, 1, 1, 8, 4), 20.0, dtype=np.float32) + fov_a = _make_fov(data_a, y_offset=0.0) + fov_b = _make_fov(data_b, y_offset=4.0) + + result = _composite_fovs([fov_a, fov_b], MeanCompositor()) + + assert result.shape == (1, 1, 1, 12, 4) + vals = result.values + np.testing.assert_allclose(vals[..., :4, :], 10.0) + np.testing.assert_allclose(vals[..., 4:8, :], 15.0) # mean(10, 20) + np.testing.assert_allclose(vals[..., 8:, :], 20.0) diff --git 
a/tests/tile/test_slicer.py b/tests/tile/test_slicer.py new file mode 100644 index 00000000..9478d7f5 --- /dev/null +++ b/tests/tile/test_slicer.py @@ -0,0 +1,163 @@ +"""Tests for Tiler tile generation.""" + +import dask.array as da +import numpy as np +import pytest +import xarray as xr +from hypothesis import given +from hypothesis import strategies as st + +from iohub.tile import SamplingMode, Tile, Tiler +from tests.tile.conftest import tile_params, tile_params_zyx + + +@given( + params=tile_params(), + mode=st.sampled_from([SamplingMode.SQUEEZE, SamplingMode.EDGE]), +) +def test_tiles_cover_full_extent(synthetic_5d, params, mode): + """For SQUEEZE/EDGE modes, tiles must cover the full YX extent.""" + tile_size, overlap = params + tiler = Tiler(synthetic_5d, tile_size=tile_size, overlap=overlap, mode=mode) + + covered_y = np.zeros(synthetic_5d.sizes["y"], dtype=bool) + covered_x = np.zeros(synthetic_5d.sizes["x"], dtype=bool) + for tile in tiler: + covered_y[tile.slices["y"]] = True + covered_x[tile.slices["x"]] = True + assert covered_y.all(), f"Y not fully covered with {tile_size}, {overlap}" + assert covered_x.all(), f"X not fully covered with {tile_size}, {overlap}" + + from functools import reduce + from operator import mul + + assert len(tiler) == reduce(mul, tiler.tile_grid_shape, 1) + + +def test_single_tile_when_oversized(synthetic_5d): + """When tile_size > data, a single tile covering the full extent is returned.""" + tiler = Tiler(synthetic_5d, tile_size={"y": 999, "x": 999}) + assert len(tiler) == 1 + assert tiler[0].tile_shape == (64, 128) + + +@pytest.mark.parametrize( + "kwargs, match", + [ + ({"data_dims": ("a", "b"), "tile_size": {"y": 5, "x": 5}}, "'y' and 'x' dims"), + ({"tile_size": {"y": 32}}, "tile_size must specify"), + ({"tile_size": {"y": 32, "x": 64}, "overlap": {"y": 32, "x": 0}}, "overlap.*must be less"), + ], +) +def test_invalid_inputs(synthetic_5d, kwargs, match): + """Invalid dimensions, missing tile_size keys, or excessive 
overlap are rejected.""" + if "data_dims" in kwargs: + data = xr.DataArray(np.zeros((10, 10)), dims=kwargs.pop("data_dims")) + else: + data = synthetic_5d + with pytest.raises(ValueError, match=match): + Tiler(data, **kwargs) + + +def test_chunk_alignment_snaps_up(): + """align_to_chunks rounds tile_size up to chunk multiples.""" + dask_data = da.from_array(np.ones((1, 1, 2, 256, 512), dtype=np.float32), chunks=(1, 1, 2, 64, 128)) + data = xr.DataArray(dask_data, dims=("t", "c", "z", "y", "x")) + tiler = Tiler(data, tile_size={"y": 20, "x": 50}, align_to_chunks=True) + # 20 → 64, 50 → 128 + assert tiler._tile_size["y"] >= 64 + assert tiler._tile_size["x"] >= 128 + + +def test_iter_yields_correct_types(synthetic_5d): + """__iter__ yields Tiles, iter_xarrays yields DataArrays.""" + tiler = Tiler(synthetic_5d, tile_size={"y": 32, "x": 64}) + tiles = list(tiler) + xas = list(tiler.iter_xarrays()) + assert all(isinstance(t, Tile) for t in tiles) + assert all(isinstance(xa, xr.DataArray) for xa in xas) + assert len(tiles) == len(xas) + + +# --------------------------------------------------------------------------- +# ZYX tiling tests +# --------------------------------------------------------------------------- + + +def test_zyx_tiler_grid_shape(synthetic_5d_large_z): + """ZYX tiler produces a 3-tuple grid shape.""" + tiler = Tiler( + synthetic_5d_large_z, + tile_size={"z": 8, "y": 32, "x": 64}, + overlap={"z": 2, "y": 8, "x": 16}, + ) + assert len(tiler.tile_grid_shape) == 3 + assert tiler.tile_dims == ("z", "y", "x") + + from functools import reduce + from operator import mul + + assert len(tiler) == reduce(mul, tiler.tile_grid_shape, 1) + + +def test_zyx_tile_spec_properties(synthetic_5d_large_z): + """ZYX Tile has correct dims, shape, and bbox.""" + tiler = Tiler( + synthetic_5d_large_z, + tile_size={"z": 8, "y": 32, "x": 64}, + ) + tile = tiler[0] + assert tile.tile_dims == ("z", "y", "x") + assert len(tile.tile_shape) == 3 + assert tile.bbox.shape == (3, 2) + 
assert "z" in tile.slices + + +@given( + params=tile_params_zyx(), + mode=st.sampled_from([SamplingMode.SQUEEZE, SamplingMode.EDGE]), +) +def test_zyx_tiles_cover_full_extent(synthetic_5d_large_z, params, mode): + """For SQUEEZE/EDGE modes, ZYX tiles must cover all tiled dimensions.""" + tile_size, overlap = params + tiler = Tiler(synthetic_5d_large_z, tile_size=tile_size, overlap=overlap, mode=mode) + + for dim in ("z", "y", "x"): + covered = np.zeros(synthetic_5d_large_z.sizes[dim], dtype=bool) + for tile in tiler: + covered[tile.slices[dim]] = True + assert covered.all(), f"{dim} not fully covered with {tile_size}, {overlap}" + + +def test_zyx_neighborhood_graph(synthetic_5d_large_z): + """ZYX neighborhood graph has Z-direction edges.""" + tiler = Tiler( + synthetic_5d_large_z, + tile_size={"z": 8, "y": 64, "x": 128}, + overlap={"z": 2, "y": 0, "x": 0}, + ) + # With only Z overlap and single tile in YX, graph should have Z-edges + assert tiler.graph.number_of_edges() > 0 + assert len(tiler) > 1 + + +def test_zyx_to_xarray_slices_correctly(synthetic_5d_large_z): + """to_xarray() on ZYX tile slices Z, Y, and X dims correctly.""" + tiler = Tiler( + synthetic_5d_large_z, + tile_size={"z": 4, "y": 32, "x": 64}, + ) + tile = tiler[0] + xa = tile.to_xarray() + assert xa.sizes["z"] == 4 + assert xa.sizes["y"] == 32 + assert xa.sizes["x"] == 64 + # T and C should be unsliced + assert xa.sizes["t"] == synthetic_5d_large_z.sizes["t"] + assert xa.sizes["c"] == synthetic_5d_large_z.sizes["c"] + + +def test_zyx_invalid_dim_rejected(synthetic_5d): + """tile_size with a dim not in data raises ValueError.""" + with pytest.raises(ValueError, match="not found in data dims"): + Tiler(synthetic_5d, tile_size={"z": 2, "y": 32, "x": 64, "w": 10}) diff --git a/tests/tile/test_tile_and_assemble.py b/tests/tile/test_tile_and_assemble.py new file mode 100644 index 00000000..564dd2ca --- /dev/null +++ b/tests/tile/test_tile_and_assemble.py @@ -0,0 +1,200 @@ +"""Tests for the three-phase 
tile store API: create_tile_store, process_tiles, stitch_from_store.""" + +import numpy as np +import pytest + +from iohub.tile import create_tile_store, process_tiles, stitch_from_store + + +def _identity(t): + return t + + +def _scale2(t): + return t * 2 + + +def test_roundtrip_yx(synthetic_position, tmp_path): + """Three-phase identity round-trip preserves data (YX tiling).""" + original = synthetic_position.data[:] + store = str(tmp_path / "tiles.zarr") + output = str(tmp_path / "out.zarr") + + batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + ) + for batch in batches: + process_tiles(synthetic_position, _identity, store, batch) + + stitch_from_store(store, output, synthetic_position, weights="uniform") + + from iohub.ngff import open_ome_zarr + + result = open_ome_zarr(output, layout="fov").data[:] + assert result.shape == original.shape + np.testing.assert_allclose(result.astype(np.float32), original.astype(np.float32), atol=1e-4) + + +def test_roundtrip_gaussian(synthetic_position, tmp_path): + """Gaussian blending round-trip preserves identity.""" + original = synthetic_position.data[:] + store = str(tmp_path / "tiles.zarr") + output = str(tmp_path / "out.zarr") + + batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + ) + for batch in batches: + process_tiles(synthetic_position, _identity, store, batch) + + stitch_from_store(store, output, synthetic_position, weights="gaussian") + + from iohub.ngff import open_ome_zarr + + result = open_ome_zarr(output, layout="fov").data[:] + np.testing.assert_allclose(result.astype(np.float32), original.astype(np.float32), atol=1e-4) + + +def test_scaling(synthetic_position, tmp_path): + """process_tiles correctly applies a scaling function.""" + original = synthetic_position.data[:] + store = str(tmp_path / "tiles.zarr") + output = str(tmp_path / "out.zarr") + + 
batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + ) + for batch in batches: + process_tiles(synthetic_position, _scale2, store, batch) + + stitch_from_store(store, output, synthetic_position, weights="uniform") + + from iohub.ngff import open_ome_zarr + + result = open_ome_zarr(output, layout="fov").data[:] + np.testing.assert_allclose(result.astype(np.float32), (original * 2).astype(np.float32), atol=1e-4) + + +def test_create_tile_store_returns_batches(synthetic_position, tmp_path): + """create_tile_store returns correct batches.""" + store = str(tmp_path / "tiles.zarr") + batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + tile_batch_size=3, + ) + all_ids = [tid for batch in batches for tid in batch] + # All IDs present, no duplicates + assert sorted(all_ids) == list(range(len(all_ids))) + # Each batch ≤ batch_size + assert all(len(b) <= 3 for b in batches) + + +def test_store_already_exists_raises(synthetic_position, tmp_path): + """create_tile_store raises if store already exists.""" + store = str(tmp_path / "tiles.zarr") + create_tile_store(synthetic_position, tile_size={"y": 32, "x": 64}, store=store) + with pytest.raises(FileExistsError): + create_tile_store(synthetic_position, tile_size={"y": 32, "x": 64}, store=store) + + +def test_tile_ids_subset_parallel_pattern(synthetic_position, tmp_path): + """process_tiles with disjoint tile_id subsets (SLURM pattern) stitches correctly.""" + original = synthetic_position.data[:] + store = str(tmp_path / "tiles.zarr") + output = str(tmp_path / "out.zarr") + + batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + tile_batch_size=3, # small batches to simulate multiple jobs + ) + # Each batch processed independently (simulates separate SLURM jobs) + for batch in batches: + 
process_tiles(synthetic_position, _identity, store, batch) + + stitch_from_store(store, output, synthetic_position, weights="uniform") + + from iohub.ngff import open_ome_zarr + + result = open_ome_zarr(output, layout="fov").data[:] + np.testing.assert_allclose(result.astype(np.float32), original.astype(np.float32), atol=1e-4) + + +def test_stitch_raises_on_missing_tile(synthetic_position, tmp_path): + """stitch_from_store raises FileNotFoundError when a tile FOV is missing.""" + store = str(tmp_path / "tiles.zarr") + output = str(tmp_path / "out.zarr") + + batches = create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + overlap={"y": 8, "x": 16}, + tile_batch_size=2, # ensure multiple batches + ) + assert len(batches) > 1, "Need multiple batches for this test" + # Only process the first batch — leave the rest missing + process_tiles(synthetic_position, _identity, store, batches[0]) + + with pytest.raises(FileNotFoundError, match="Tile"): + stitch_from_store(store, output, synthetic_position) + + +def test_process_tiles_wrong_shape_raises(synthetic_position, tmp_path): + """process_tiles raises if fn returns wrong spatial shape.""" + store = str(tmp_path / "tiles.zarr") + create_tile_store( + synthetic_position, + tile_size={"y": 32, "x": 64}, + store=store, + ) + + # fn that changes spatial shape + def bad_fn(t): + return t.values[:, :, :, :16, :16] # crop to wrong size + + with pytest.raises(ValueError, match="preserve spatial dimensions"): + process_tiles(synthetic_position, bad_fn, store, [0]) + + +# --------------------------------------------------------------------------- +# ZYX tiling tests +# --------------------------------------------------------------------------- + + +def test_zyx_roundtrip(synthetic_position_large_z, tmp_path): + """Three-phase ZYX identity round-trip preserves data.""" + original = synthetic_position_large_z.data[:] + store = str(tmp_path / "tiles_zyx.zarr") + output = str(tmp_path / "out_zyx.zarr") 
+ + batches = create_tile_store( + synthetic_position_large_z, + tile_size={"z": 8, "y": 32, "x": 64}, + store=store, + overlap={"z": 2, "y": 8, "x": 16}, + ) + for batch in batches: + process_tiles(synthetic_position_large_z, _identity, store, batch) + + stitch_from_store(store, output, synthetic_position_large_z, weights="uniform") + + from iohub.ngff import open_ome_zarr + + result = open_ome_zarr(output, layout="fov").data[:] + assert result.shape == original.shape + np.testing.assert_allclose(result.astype(np.float32), original.astype(np.float32), atol=1e-4) diff --git a/uv.lock b/uv.lock index 34984fe1..2d9bed23 100644 --- a/uv.lock +++ b/uv.lock @@ -783,6 +783,8 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-extra-types" }, { name = "rich" }, + { name = "scipy" }, + { name = "submitit" }, { name = "tifffile" }, { name = "tqdm" }, { name = "xarray" }, @@ -839,6 +841,8 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.8.2" }, { name = "pydantic-extra-types", specifier = ">=2.9.0" }, { name = "rich" }, + { name = "scipy", specifier = ">=1.17.0" }, + { name = "submitit", specifier = ">=1.5.4" }, { name = "tensorstore", marker = "extra == 'tensorstore'", specifier = ">=0.1.64" }, { name = "tifffile", specifier = ">=2025.5.21" }, { name = "tqdm" }, @@ -2645,6 +2649,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "submitit" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpickle" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/86/497018fb3b74e71bef45df82762b176e6b3d159f29941c20d2f141ec4096/submitit-1.5.4.tar.gz", hash = 
"sha256:7100848bd1cdda79c7196e54ee830793ae75fd7adde0c5bef738d72360a07508", size = 81538, upload-time = "2025-12-17T19:20:03.396Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/bb/711e1c2ebd18a21202c972dd5d5c8e09a921f2d3560e3a53d6350c808ab7/submitit-1.5.4-py3-none-any.whl", hash = "sha256:c26f3a7c8d4150eaf70b1da71e2023e9e9936c93e8342ed7db910f29158561c5", size = 76043, upload-time = "2025-12-17T19:20:01.941Z" }, +] + [[package]] name = "tensorstore" version = "0.1.80"