Skip to content

Commit b0684bb

Browse files
Add xarray integration to Position (#370)
* Add Position.to_xarray() method * Add Position.write_xarray() method * Add tests for to_xarray and write_xarray * better channel names handling * lowercase dimension names * nest xarray attributes under "iohub" key * rng fixture * revise tests to use hypothesis * use oindex instead of double loop
1 parent 3c6ae7d commit b0684bb

3 files changed

Lines changed: 649 additions & 18 deletions

File tree

iohub/ngff/nodes.py

Lines changed: 159 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Generator, Literal, Sequence, Type, TypeAlias, overload
1616

1717
import numpy as np
18+
import xarray as xr
1819
import zarr.codecs
1920
from numpy.typing import ArrayLike, DTypeLike, NDArray
2021
from pydantic import ValidationError
@@ -984,18 +985,12 @@ def initialize_pyramid(self, levels: int) -> None:
984985
chunks = _pad_shape(_scale_integers(array.chunks, factor), len(shape))
985986

986987
if array.shards is not None:
987-
shards = array.shards[:-3] + _scale_integers(
988-
array.shards[-3:], factor
989-
)
988+
shards = array.shards[:-3] + _scale_integers(array.shards[-3:], factor)
990989
shards_ratio = tuple(s // c for c, s in zip(chunks, shards))
991990
else:
992991
shards_ratio = None
993992

994-
transforms = deepcopy(
995-
self.metadata.multiscales[0]
996-
.datasets[0]
997-
.coordinate_transformations
998-
)
993+
transforms = deepcopy(self.metadata.multiscales[0].datasets[0].coordinate_transformations)
999994
for tr in transforms:
1000995
if tr.type == "scale":
1001996
for i in range(len(tr.scale))[-3:]:
@@ -1051,16 +1046,12 @@ def compute_pyramid(
10511046

10521047
num_arrays = len(self.array_keys())
10531048
if num_arrays == 0:
1054-
raise ValueError(
1055-
"No level 0 array exists. Create base array before computing "
1056-
"pyramid."
1057-
)
1049+
raise ValueError("No level 0 array exists. Create base array before computing pyramid.")
10581050

10591051
if levels is None:
10601052
if num_arrays == 1:
10611053
raise ValueError(
1062-
"Pyramid structure doesn't exist and levels=None. "
1063-
"Specify 'levels' parameter to create pyramid."
1054+
"Pyramid structure doesn't exist and levels=None. Specify 'levels' parameter to create pyramid."
10641055
)
10651056
levels = num_arrays
10661057

@@ -1084,10 +1075,7 @@ def compute_pyramid(
10841075
current_scale = self.get_effective_scale(str(level))
10851076
previous_scale = self.get_effective_scale(str(level - 1))
10861077

1087-
downsample_factors = [
1088-
int(round(current_scale[i] / previous_scale[i]))
1089-
for i in range(len(current_scale))
1090-
]
1078+
downsample_factors = [int(round(current_scale[i] / previous_scale[i])) for i in range(len(current_scale))]
10911079

10921080
_downsample_tensorstore(
10931081
source_ts=previous_ts,
@@ -1328,6 +1316,159 @@ def set_contrast_limits(self, channel_name: str, window: WindowDict):
13281316
self.metadata.omero.channels[channel_index].window = window
13291317
self.dump_meta()
13301318

1319+
def to_xarray(self) -> xr.DataArray:
    """Export the full Position data as a labeled xarray.DataArray (tczyx).

    The result is backed by a dask array, so no pixel data is loaded
    until ``.values`` or ``.compute()`` is called.

    Coordinate units follow CF conventions: each coordinate carries its
    own ``attrs["units"]`` (e.g. ``xa.coords["z"].attrs["units"]
    == "micrometer"``). ``xa.attrs`` is reserved for value-level
    metadata (e.g. ``xa.attrs["units"] = "nanometer"`` for ret).

    Returns
    -------
    xr.DataArray
        5D labeled array with coordinates derived from
        channel names and physical scales/units.
    """
    channel_names = self.channel_names
    scale = self.scale
    translation = self.get_effective_translation(
        self.metadata.multiscales[0].datasets[0].path
    )

    lazy = self.data.dask_array()
    sizes = dict(zip("tczyx", lazy.shape))

    # Map lowercase axis name -> unit, from OME axis metadata
    # (axes without a unit are simply omitted).
    units_by_axis = {
        ax.name.lower(): unit
        for ax in self.axes
        if (unit := getattr(ax, "unit", None)) is not None
    }

    # CF convention: units live in per-coordinate attrs.
    # Axis index into scale/translation: t=0, z=2, y=3, x=4 (c has none).
    coords = {"c": ("c", channel_names)}
    for dim, axis_index in (("t", 0), ("z", 2), ("y", 3), ("x", 4)):
        ticks = np.arange(sizes[dim]) * scale[axis_index] + translation[axis_index]
        coord_attrs = (
            {"units": units_by_axis[dim]} if dim in units_by_axis else {}
        )
        coords[dim] = (dim, ticks, coord_attrs)

    # Restore any previously saved DataArray attrs from zarr metadata.
    stored = self.zattrs.get("iohub", {})

    return xr.DataArray(
        lazy,
        dims=("t", "c", "z", "y", "x"),
        coords=coords,
        attrs=dict(stored.get("xarray_attrs", {})),
    )
1368+
1369+
def write_xarray(self, data_array: xr.DataArray, image: str = "0") -> None:
    """Write an xarray.DataArray into this Position.

    Supports writing a subset of channels and/or timepoints.
    The image array is created on first call; subsequent calls
    write into the existing array at the correct indices.

    Scales, translations, axis units, and DataArray attrs are
    set from the first write and updated on subsequent writes.

    Parameters
    ----------
    data_array : xr.DataArray
        5D labeled array with tczyx dimensions.
        The "c" coordinate must be a subset of this Position's
        channel names. "t" coordinates are mapped to time indices
        via the scale and translation.
    image : str, optional
        Name of the image array to write to, by default "0".
    """
    if tuple(data_array.dims) != ("t", "c", "z", "y", "x"):
        raise ValueError(
            f"DataArray dims must be ('t', 'c', 'z', 'y', 'x'), got {data_array.dims}"
        )

    # Every requested channel must already exist on this Position.
    requested_channels = list(data_array.coords["c"].values)
    for name in requested_channels:
        if name not in self.channel_names:
            raise ValueError(
                f"Channel '{name}' not in this Position's "
                f"channel names {self.channel_names}"
            )

    # Infer per-axis spacing and origin from the physical coordinates.
    def _step(values) -> float:
        # A single sample carries no spacing information; default to 1.0.
        return float(values[1] - values[0]) if len(values) > 1 else 1.0

    coord_values = {dim: data_array.coords[dim].values for dim in "tzyx"}
    steps = {dim: _step(vals) for dim, vals in coord_values.items()}
    origins = {dim: float(vals[0]) for dim, vals in coord_values.items()}

    # Coordinate units live in per-coordinate attrs (CF convention).
    def _unit(dim: str, fallback: str) -> str:
        return data_array.coords[dim].attrs.get("units", fallback)

    self.axes = [
        TimeAxisMeta(name="T", unit=_unit("t", "second")),
        ChannelAxisMeta(name="C"),
        SpaceAxisMeta(name="Z", unit=_unit("z", "micrometer")),
        SpaceAxisMeta(name="Y", unit=_unit("y", "micrometer")),
        SpaceAxisMeta(name="X", unit=_unit("x", "micrometer")),
    ]

    transforms = [
        TransformationMeta(
            type="scale",
            scale=[steps["t"], 1.0, steps["z"], steps["y"], steps["x"]],
        )
    ]
    # Only record a translation transform when some origin is non-zero.
    if any(origins[dim] != 0.0 for dim in "tzyx"):
        transforms.append(
            TransformationMeta(
                type="translation",
                translation=[
                    origins["t"],
                    0.0,
                    origins["z"],
                    origins["y"],
                    origins["x"],
                ],
            )
        )

    pixels = data_array.values

    # First write: allocate a zero-filled array covering all channels.
    if image not in self:
        n_t = len(data_array.coords["t"])
        _, _, n_z, n_y, n_x = pixels.shape
        self.create_zeros(
            image,
            shape=(n_t, len(self.channel_names), n_z, n_y, n_x),
            dtype=pixels.dtype,
            transform=transforms,
        )

    # Resolve channel names to array indices.
    channel_indices = [self.get_channel_index(ch) for ch in requested_channels]

    # Convert physical t coordinates back to integer time indices
    # using the array's effective scale and translation.
    eff_scale = self.get_effective_scale(image)
    eff_translation = self.get_effective_translation(image)
    time_indices = np.round(
        (data_array.coords["t"].values - eff_translation[0]) / eff_scale[0]
    ).astype(int)

    # Orthogonal indexing writes the t x c block in one assignment.
    self[image].oindex[time_indices, channel_indices] = pixels

    # Persist DataArray attrs under the "iohub" key for round-tripping.
    if data_array.attrs:
        stored = dict(self.zattrs.get("iohub", {}))
        stored["xarray_attrs"] = dict(data_array.attrs)
        self.zattrs["iohub"] = stored
1471+
13311472

13321473
class TiledPosition(Position):
13331474
"""Variant of the NGFF position node

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from wget import download
1010

1111

12+
@pytest.fixture
def rng():
    """Seeded numpy random generator for reproducible tests."""
    seed = 42
    return np.random.default_rng(seed)
15+
16+
1217
def _download_ndtiff_v3_labeled_positions(test_data: Path) -> None:
1318
ghfs = fsspec.filesystem(
1419
"github",

0 commit comments

Comments
 (0)