diff --git a/iohub/cli/cli.py b/iohub/cli/cli.py index 9dad5657..d44b9b57 100644 --- a/iohub/cli/cli.py +++ b/iohub/cli/cli.py @@ -123,12 +123,18 @@ def convert(input, output, grid_layout, chunks): type=float, help="New x scale", ) +@click.option( + "--image", + required=False, + help="Image name to set scale for. Default is '0'", +) def set_scale( input_position_dirpaths, t_scale=None, z_scale=None, y_scale=None, x_scale=None, + image=None, ): """Update scale metadata in OME-Zarr datasets. @@ -138,6 +144,8 @@ def set_scale( >> iohub set-scale -i input.zarr/*/*/* -z 2.0 """ + if image is None: + image = "0" for input_position_dirpath in input_position_dirpaths: with open_ome_zarr( input_position_dirpath, layout="fov", mode="r+" @@ -147,12 +155,7 @@ def set_scale( ): if value is None: continue - old_value = dataset.scale[dataset.get_axis_index(name)] - print( - f"Updating {input_position_dirpath} {name} scale from " - f"{old_value} to {value}." - ) - dataset.set_scale("0", name, value) + dataset.set_scale(image, name, value) @cli.command(name="rename-wells") diff --git a/iohub/ngff/nodes.py b/iohub/ngff/nodes.py index f9b515de..02d36e2a 100644 --- a/iohub/ngff/nodes.py +++ b/iohub/ngff/nodes.py @@ -9,6 +9,7 @@ import math import os from copy import deepcopy +from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Generator, Literal, Sequence, Type @@ -1144,17 +1145,14 @@ def set_transform( self.dump_meta() def set_scale( - self, - image: str | Literal["*"], - axis_name: str, - new_scale: float, + self, image: str | Literal["*"], axis_name: str, new_scale: float ): """Set the scale for a named axis. Either one image array or the whole FOV. Parameters ---------- - image : str | Literal[ + image : str | Literal['*'] Name of one image array (e.g. "0") to transform, or "*" for the whole FOV axis_name : str @@ -1162,39 +1160,56 @@ def set_scale( new_scale : float Value of the new scale. """ - if len(self.metadata.multiscales) > 1: - raise NotImplementedError( - "Cannot set scale for multi-resolution images." - ) - if new_scale <= 0: - raise ValueError("New scale must be positive.") - + raise ValueError( + f"New scale {axis_name}: {new_scale} is not positive!" + ) + if image not in self and image != "*": + raise KeyError(f"Image {image} not found.") axis_index = self.get_axis_index(axis_name) - + # Update scale while preserving existing transforms + if image == "*": + transforms = ( + self.metadata.multiscales[0].coordinate_transformations or [] + ) + else: + for dataset_meta in self.metadata.multiscales[0].datasets: + if dataset_meta.path == image: + transforms = dataset_meta.coordinate_transformations + break # Append old scale to metadata iohub_dict = {} if "iohub" in self.zattrs: iohub_dict = self.zattrs["iohub"] - iohub_dict.update({f"prior_{axis_name}_scale": self.scale[axis_index]}) - self.zattrs["iohub"] = iohub_dict - - # Update scale while preserving existing transforms - transforms = ( - self.metadata.multiscales[0].datasets[0].coordinate_transformations + if "previous_transforms" not in iohub_dict: + iohub_dict["previous_transforms"] = [] + iohub_dict["previous_transforms"].append( + { + "image": image, + "transforms": [ + t.model_dump(**TO_DICT_SETTINGS) for t in transforms + ], + "modified": datetime.now().isoformat(), + } ) + self.zattrs["iohub"] = iohub_dict # Replace default identity transform with scale - if len(transforms) == 1 and transforms[0].type == "identity": + if transforms == [TransformationMeta(type="identity")]: transforms = [TransformationMeta(type="scale", scale=[1] * 5)] # Add scale transform if not present if not any([transform.type == "scale" for transform in transforms]): transforms.append(TransformationMeta(type="scale", scale=[1] * 5)) - + new_transforms = [] for transform in transforms: if transform.type == "scale": + old_scale = transform.scale[axis_index] transform.scale[axis_index] = new_scale - - self.set_transform(image, transforms) + new_transforms.append(transform) + _logger.info( + f"Updating scale for axis {axis_name} " + f"from {old_scale} to {new_scale}." + ) + self.set_transform(image, new_transforms) def set_contrast_limits(self, channel_name: str, window: WindowDict): """Set the contrast limits for a channel. diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index f010e58d..2b58420b 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -121,7 +121,7 @@ def test_cli_convert_ome_tiff(grid_layout, tmpdir): assert "Converting" in result.output -def test_cli_set_scale(): +def test_cli_set_scale(caplog): with _temp_copy(hcs_ref) as store_path: store_path = Path(store_path) position_path = Path(store_path) / "B" / "03" / "0" @@ -149,23 +149,16 @@ def test_cli_set_scale(): ], ) assert result_pos.exit_code == 0 - assert "Updating" in result_pos.output - + assert any("Updating" in record.message for record in caplog.records) with open_ome_zarr(position_path, layout="fov") as output_dataset: assert tuple(output_dataset.scale[-3:]) == (random_z, 0.5, 0.5) assert output_dataset.scale != old_scale - assert ( - output_dataset.zattrs["iohub"]["prior_x_scale"] - == old_scale[-1] - ) - assert ( - output_dataset.zattrs["iohub"]["prior_y_scale"] - == old_scale[-2] - ) - assert ( - output_dataset.zattrs["iohub"]["prior_z_scale"] - == old_scale[-3] - ) + for i, record in enumerate( + output_dataset.zattrs["iohub"]["previous_transforms"] + ): + for transform in record["transforms"]: + if transform["type"] == "scale": + assert transform["scale"][-3:][i] == old_scale[-3:][i] # Test plate-expands-into-positions behavior runner = CliRunner() @@ -181,7 +174,11 @@ def test_cli_set_scale(): ) with open_ome_zarr(position_path, layout="fov") as output_dataset: assert output_dataset.scale[-1] == 0.1 - assert output_dataset.zattrs["iohub"]["prior_x_scale"] == 0.5 + for transform in output_dataset.zattrs["iohub"][ + "previous_transforms" + ][-1]["transforms"]: + if transform["type"] == "scale": + assert transform["scale"][-1] == 0.5 def test_cli_rename_wells_help(): diff --git a/tests/ngff/test_ngff.py b/tests/ngff/test_ngff.py index e34ee255..595a59fb 100644 --- a/tests/ngff/test_ngff.py +++ b/tests/ngff/test_ngff.py @@ -642,38 +642,51 @@ def test_set_transform_fov(ch_shape_dtype, arr_name): ] -@given( - ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(), -) -@settings(deadline=None) -def test_set_scale(ch_shape_dtype): +@pytest.mark.parametrize("image_name", ["0", "1", "a", "*"]) +def test_set_scale(image_name): """Test `iohub.ngff.Position.set_scale()`""" - channel_names, shape, dtype = ch_shape_dtype - transform = [ - TransformationMeta(type="translation", translation=(1, 2, 3, 4, 5)), - TransformationMeta(type="scale", scale=(5, 4, 3, 2, 1)), - ] + translation = [float(t) for t in range(1, 6)] + scale = [float(s) for s in range(5, 0, -1)] + array_name = "0" if image_name == "*" else image_name + new_scale = 10.0 with TemporaryDirectory() as temp_dir: store_path = os.path.join(temp_dir, "ome.zarr") with open_ome_zarr( - store_path, layout="fov", mode="w-", channel_names=channel_names + store_path, layout="fov", mode="w-", channel_names=["a", "b"] ) as dataset: - dataset.create_zeros(name="0", shape=shape, dtype=dtype) - dataset.set_transform(image="0", transform=transform) - dataset.set_scale(image="0", axis_name="z", new_scale=10.0) - assert dataset.scale[-3] == 10.0 - assert ( - dataset.metadata.multiscales[0] - .datasets[0] - .coordinate_transformations[0] - .translation[-1] - == 5 + dataset.create_zeros( + name=array_name, + shape=(1, 2, 4, 8, 16), + dtype=int, + transform=[ + TransformationMeta( + type="translation", translation=translation + ), + TransformationMeta(type="scale", scale=scale), + ], ) - with pytest.raises(ValueError): - dataset.set_scale(image="0", axis_name="z", new_scale=-1.0) - - assert dataset.zattrs["iohub"]["prior_z_scale"] == 3.0 + dataset.set_scale( + image=image_name, axis_name="z", new_scale=-1.0 + ) + with pytest.raises(KeyError): + dataset.set_scale( + image="nonexistent", axis_name="z", new_scale=9.0 + ) + assert dataset.scale[-3] == 3.0 + dataset.set_scale( + image=image_name, axis_name="z", new_scale=new_scale + ) + if image_name == "*": + assert dataset.scale[-3] == new_scale * 3.0 + else: + assert dataset.scale[-3] == new_scale + assert dataset.get_effective_translation(array_name) == translation + for tf in dataset.zattrs["iohub"]["previous_transforms"][0][ + "transforms" + ]: + if tf["type"] == "scale": + assert tf["scale"] == scale @given(channel_names=channel_names_st)