1515from typing import Generator , Literal , Sequence , Type , TypeAlias , overload
1616
1717import numpy as np
18+ import xarray as xr
1819import zarr .codecs
1920from numpy .typing import ArrayLike , DTypeLike , NDArray
2021from pydantic import ValidationError
@@ -984,18 +985,12 @@ def initialize_pyramid(self, levels: int) -> None:
984985 chunks = _pad_shape (_scale_integers (array .chunks , factor ), len (shape ))
985986
986987 if array .shards is not None :
987- shards = array .shards [:- 3 ] + _scale_integers (
988- array .shards [- 3 :], factor
989- )
988+ shards = array .shards [:- 3 ] + _scale_integers (array .shards [- 3 :], factor )
990989 shards_ratio = tuple (s // c for c , s in zip (chunks , shards ))
991990 else :
992991 shards_ratio = None
993992
994- transforms = deepcopy (
995- self .metadata .multiscales [0 ]
996- .datasets [0 ]
997- .coordinate_transformations
998- )
993+ transforms = deepcopy (self .metadata .multiscales [0 ].datasets [0 ].coordinate_transformations )
999994 for tr in transforms :
1000995 if tr .type == "scale" :
1001996 for i in range (len (tr .scale ))[- 3 :]:
@@ -1051,16 +1046,12 @@ def compute_pyramid(
10511046
10521047 num_arrays = len (self .array_keys ())
10531048 if num_arrays == 0 :
1054- raise ValueError (
1055- "No level 0 array exists. Create base array before computing "
1056- "pyramid."
1057- )
1049+ raise ValueError ("No level 0 array exists. Create base array before computing pyramid." )
10581050
10591051 if levels is None :
10601052 if num_arrays == 1 :
10611053 raise ValueError (
1062- "Pyramid structure doesn't exist and levels=None. "
1063- "Specify 'levels' parameter to create pyramid."
1054+ "Pyramid structure doesn't exist and levels=None. Specify 'levels' parameter to create pyramid."
10641055 )
10651056 levels = num_arrays
10661057
@@ -1084,10 +1075,7 @@ def compute_pyramid(
10841075 current_scale = self .get_effective_scale (str (level ))
10851076 previous_scale = self .get_effective_scale (str (level - 1 ))
10861077
1087- downsample_factors = [
1088- int (round (current_scale [i ] / previous_scale [i ]))
1089- for i in range (len (current_scale ))
1090- ]
1078+ downsample_factors = [int (round (current_scale [i ] / previous_scale [i ])) for i in range (len (current_scale ))]
10911079
10921080 _downsample_tensorstore (
10931081 source_ts = previous_ts ,
@@ -1328,6 +1316,159 @@ def set_contrast_limits(self, channel_name: str, window: WindowDict):
13281316 self .metadata .omero .channels [channel_index ].window = window
13291317 self .dump_meta ()
13301318
def to_xarray(self) -> xr.DataArray:
    """Return this Position's image data as a labeled ``xarray.DataArray``.

    The returned array wraps the dask array obtained from
    ``self.data.dask_array()``, so pixel values are not loaded until
    ``.values`` or ``.compute()`` is called.

    Dimensions are ``("t", "c", "z", "y", "x")``. The ``c`` coordinate holds
    the channel names. Each of ``t``, ``z``, ``y`` and ``x`` gets coordinate
    values ``index * scale + translation``. Following CF conventions, a unit
    found on the matching axis metadata is stored on that coordinate as
    ``attrs["units"]`` (e.g. ``xa.coords["z"].attrs["units"]``), while the
    array-level ``attrs`` are reserved for value-level metadata and are
    restored from ``zattrs["iohub"]["xarray_attrs"]`` when present.

    Returns
    -------
    xr.DataArray
        5D labeled array with coordinates derived from channel names and
        physical scales/units.
    """
    channel_labels = self.channel_names
    pixel_scale = self.scale
    base_path = self.metadata.multiscales[0].datasets[0].path
    origin = self.get_effective_translation(base_path)

    lazy_data = self.data.dask_array()
    # Unpacking into five names also enforces that the array is exactly 5D.
    n_t, _n_c, n_z, n_y, n_x = lazy_data.shape

    # Axes without a ``unit`` attribute (or with ``unit=None``) are left out.
    units_by_dim = {
        axis_meta.name.lower(): axis_unit
        for axis_meta in self.axes
        if (axis_unit := getattr(axis_meta, "unit", None)) is not None
    }

    coords = {"c": ("c", channel_labels)}
    # (dimension name, length, position of that axis in scale/translation)
    physical_dims = (("t", n_t, 0), ("z", n_z, 2), ("y", n_y, 3), ("x", n_x, 4))
    for dim, length, axis_idx in physical_dims:
        positions = np.arange(length) * pixel_scale[axis_idx] + origin[axis_idx]
        coord_attrs = {}
        if dim in units_by_dim:
            coord_attrs["units"] = units_by_dim[dim]
        coords[dim] = (dim, positions, coord_attrs)

    # Copy so the returned attrs are detached from the stored mapping.
    stored_attrs = dict(self.zattrs.get("iohub", {}).get("xarray_attrs", {}))

    return xr.DataArray(
        lazy_data,
        dims=("t", "c", "z", "y", "x"),
        coords=coords,
        attrs=stored_attrs,
    )
1368+
def write_xarray(self, data_array: xr.DataArray, image: str = "0") -> None:
    """Write an xarray.DataArray into this Position.

    Supports writing a subset of channels and/or timepoints.
    The image array is created on first call; subsequent calls
    write into the existing array at the correct indices.

    ``self.axes`` is reassigned on every call from the per-coordinate
    ``attrs["units"]``. Scale and translation transforms are inferred
    from the coordinates but are only applied when this call creates
    the image array; an existing array keeps its transforms. Non-empty
    DataArray attrs are stored under ``zattrs["iohub"]["xarray_attrs"]``
    on every call.

    Parameters
    ----------
    data_array : xr.DataArray
        5D labeled array with tczyx dimensions.
        The "c" coordinate must be a subset of this Position's
        channel names. "t" coordinates are mapped to time indices
        via the effective scale and translation of ``image``.
    image : str, optional
        Name of the image array to write to, by default "0".

    Raises
    ------
    ValueError
        If the dimensions are not ``("t", "c", "z", "y", "x")`` or a
        channel is not one of this Position's channel names.
    IndexError
        If a "t" coordinate maps to a negative time index.
    """
    if tuple(data_array.dims) != ("t", "c", "z", "y", "x"):
        raise ValueError(f"DataArray dims must be ('t', 'c', 'z', 'y', 'x'), got {data_array.dims}")

    # Validate channels are a subset (read the channel list once).
    known_channels = self.channel_names
    xa_channels = list(data_array.coords["c"].values)
    for ch in xa_channels:
        if ch not in known_channels:
            raise ValueError(f"Channel '{ch}' not in this Position's channel names {known_channels}")

    # Infer scales and translations from coordinates
    def _coord_scale(coord_values):
        # A single sample carries no spacing information; fall back to 1.0.
        if len(coord_values) < 2:
            return 1.0
        return float(coord_values[1] - coord_values[0])

    t_coords = data_array.coords["t"].values
    z_coords = data_array.coords["z"].values
    y_coords = data_array.coords["y"].values
    x_coords = data_array.coords["x"].values

    t_scale = _coord_scale(t_coords)
    z_scale = _coord_scale(z_coords)
    y_scale = _coord_scale(y_coords)
    x_scale = _coord_scale(x_coords)

    t_trans = float(t_coords[0])
    z_trans = float(z_coords[0])
    y_trans = float(y_coords[0])
    x_trans = float(x_coords[0])

    # Read coordinate units from per-coordinate attrs (CF convention)
    def _coord_unit(dim, default):
        return data_array.coords[dim].attrs.get("units", default)

    self.axes = [
        TimeAxisMeta(name="T", unit=_coord_unit("t", "second")),
        ChannelAxisMeta(name="C"),
        SpaceAxisMeta(name="Z", unit=_coord_unit("z", "micrometer")),
        SpaceAxisMeta(name="Y", unit=_coord_unit("y", "micrometer")),
        SpaceAxisMeta(name="X", unit=_coord_unit("x", "micrometer")),
    ]

    transforms = [
        TransformationMeta(
            type="scale",
            scale=[t_scale, 1.0, z_scale, y_scale, x_scale],
        )
    ]
    # Only record a translation when at least one axis has a non-zero origin.
    if any(v != 0.0 for v in (t_trans, z_trans, y_trans, x_trans)):
        transforms.append(
            TransformationMeta(
                type="translation",
                translation=[t_trans, 0.0, z_trans, y_trans, x_trans],
            )
        )

    np_data = data_array.values

    # Create image array if it doesn't exist yet. It spans every channel of
    # this Position but only the timepoints present in this first write.
    if image not in self:
        _, _, Z, Y, X = np_data.shape
        full_shape = (len(t_coords), len(known_channels), Z, Y, X)
        self.create_zeros(
            image,
            shape=full_shape,
            dtype=np_data.dtype,
            transform=transforms,
        )

    # Map channel names to indices
    c_indices = [self.get_channel_index(ch) for ch in xa_channels]

    # Map T coordinates to indices using the array's own scale and
    # translation (which may differ from the ones inferred above when the
    # array already existed).
    scale = self.get_effective_scale(image)
    translation = self.get_effective_translation(image)
    t_indices = np.round((t_coords - translation[0]) / scale[0]).astype(int)

    # Integer-array selections wrap negative indices around, which would
    # silently overwrite timepoints at the end of the array. Reject them.
    negative = t_indices < 0
    if negative.any():
        raise IndexError(
            f"Time coordinates {t_coords[negative].tolist()} map to negative "
            f"time indices {t_indices[negative].tolist()} in image '{image}' "
            f"(scale {scale[0]}, translation {translation[0]})"
        )

    arr = self[image]
    arr.oindex[t_indices, c_indices] = np_data

    # Persist DataArray attrs to zarr for round-tripping
    if data_array.attrs:
        iohub_dict = dict(self.zattrs.get("iohub", {}))
        iohub_dict["xarray_attrs"] = dict(data_array.attrs)
        self.zattrs["iohub"] = iohub_dict
1471+
13311472
13321473class TiledPosition (Position ):
13331474 """Variant of the NGFF position node
0 commit comments