Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions iohub/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,14 @@ def convert(input, output, grid_layout, chunks):
type=float,
help="New x scale",
)
@click.option("--image", required=False, help="Image name to set scale for")
Comment thread
ziw-liu marked this conversation as resolved.
Outdated
def set_scale(
input_position_dirpaths,
t_scale=None,
z_scale=None,
y_scale=None,
x_scale=None,
image=None,
):
"""Update scale metadata in OME-Zarr datasets.

Expand All @@ -138,6 +140,8 @@ def set_scale(

>> iohub set-scale -i input.zarr/*/*/* -z 2.0
"""
if image is None:
image = "0"
for input_position_dirpath in input_position_dirpaths:
with open_ome_zarr(
input_position_dirpath, layout="fov", mode="r+"
Expand All @@ -147,12 +151,7 @@ def set_scale(
):
if value is None:
continue
old_value = dataset.scale[dataset.get_axis_index(name)]
print(
f"Updating {input_position_dirpath} {name} scale from "
f"{old_value} to {value}."
)
dataset.set_scale("0", name, value)
dataset.set_scale(image, name, value)


@cli.command(name="rename-wells")
Expand Down
61 changes: 38 additions & 23 deletions iohub/ngff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import math
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Generator, Literal, Sequence, Type

Expand Down Expand Up @@ -1144,57 +1145,71 @@ def set_transform(
self.dump_meta()

def set_scale(
self,
image: str | Literal["*"],
axis_name: str,
new_scale: float,
self, image: str | Literal["*"], axis_name: str, new_scale: float
):
"""Set the scale for a named axis.
Either one image array or the whole FOV.

Parameters
----------
image : str | Literal[
image : str | Literal['*']
Name of one image array (e.g. "0") to transform,
or "*" for the whole FOV
axis_name : str
Name of the axis to set.
new_scale : float
Value of the new scale.
"""
if len(self.metadata.multiscales) > 1:
raise NotImplementedError(
"Cannot set scale for multi-resolution images."
)

if new_scale <= 0:
raise ValueError("New scale must be positive.")

raise ValueError(
f"New scale {axis_name}: {new_scale} is not positive!"
)
if image not in self and image != "*":
raise KeyError(f"Image {image} not found.")
axis_index = self.get_axis_index(axis_name)

# Update scale while preserving existing transforms
if image == "*":
transforms = (
self.metadata.multiscales[0].coordinate_transformations or []
)
else:
for dataset_meta in self.metadata.multiscales[0].datasets:
if dataset_meta.path == image:
transforms = dataset_meta.coordinate_transformations
break
# Append old scale to metadata
iohub_dict = {}
if "iohub" in self.zattrs:
iohub_dict = self.zattrs["iohub"]
iohub_dict.update({f"prior_{axis_name}_scale": self.scale[axis_index]})
self.zattrs["iohub"] = iohub_dict

# Update scale while preserving existing transforms
transforms = (
self.metadata.multiscales[0].datasets[0].coordinate_transformations
if "previous_transforms" not in iohub_dict:
iohub_dict["previous_transforms"] = []
iohub_dict["previous_transforms"].append(
{
"image": image,
"transforms": [
t.model_dump(**TO_DICT_SETTINGS) for t in transforms
],
"modified": datetime.now().isoformat(),
Comment thread
ziw-liu marked this conversation as resolved.
}
)
self.zattrs["iohub"] = iohub_dict
# Replace default identity transform with scale
if len(transforms) == 1 and transforms[0].type == "identity":
if transforms == [TransformationMeta(type="identity")]:
Comment thread
ziw-liu marked this conversation as resolved.
transforms = [TransformationMeta(type="scale", scale=[1] * 5)]
# Add scale transform if not present
if not any([transform.type == "scale" for transform in transforms]):
transforms.append(TransformationMeta(type="scale", scale=[1] * 5))

new_transforms = []
for transform in transforms:
if transform.type == "scale":
old_scale = transform.scale[axis_index]
transform.scale[axis_index] = new_scale

self.set_transform(image, transforms)
new_transforms.append(transform)
_logger.info(
f"Updating scale for axis {axis_name} "
f"from {old_scale} to {new_scale}."
)
self.set_transform(image, new_transforms)

def set_contrast_limits(self, channel_name: str, window: WindowDict):
"""Set the contrast limits for a channel.
Expand Down
29 changes: 13 additions & 16 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_cli_convert_ome_tiff(grid_layout, tmpdir):
assert "Converting" in result.output


def test_cli_set_scale():
def test_cli_set_scale(caplog):
with _temp_copy(hcs_ref) as store_path:
store_path = Path(store_path)
position_path = Path(store_path) / "B" / "03" / "0"
Expand Down Expand Up @@ -149,23 +149,16 @@ def test_cli_set_scale():
],
)
assert result_pos.exit_code == 0
assert "Updating" in result_pos.output

assert any("Updating" in record.message for record in caplog.records)
with open_ome_zarr(position_path, layout="fov") as output_dataset:
assert tuple(output_dataset.scale[-3:]) == (random_z, 0.5, 0.5)
assert output_dataset.scale != old_scale
assert (
output_dataset.zattrs["iohub"]["prior_x_scale"]
== old_scale[-1]
)
assert (
output_dataset.zattrs["iohub"]["prior_y_scale"]
== old_scale[-2]
)
assert (
output_dataset.zattrs["iohub"]["prior_z_scale"]
== old_scale[-3]
)
for i, record in enumerate(
output_dataset.zattrs["iohub"]["previous_transforms"]
):
for transform in record["transforms"]:
if transform["type"] == "scale":
assert transform["scale"][-3:][i] == old_scale[-3:][i]

# Test plate-expands-into-positions behavior
runner = CliRunner()
Expand All @@ -181,7 +174,11 @@ def test_cli_set_scale():
)
with open_ome_zarr(position_path, layout="fov") as output_dataset:
assert output_dataset.scale[-1] == 0.1
assert output_dataset.zattrs["iohub"]["prior_x_scale"] == 0.5
for transform in output_dataset.zattrs["iohub"][
"previous_transforms"
][-1]["transforms"]:
if transform["type"] == "scale":
assert transform["scale"][-1] == 0.5


def test_cli_rename_wells_help():
Expand Down
63 changes: 38 additions & 25 deletions tests/ngff/test_ngff.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,38 +642,51 @@ def test_set_transform_fov(ch_shape_dtype, arr_name):
]


@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
)
@settings(deadline=None)
def test_set_scale(ch_shape_dtype):
@pytest.mark.parametrize("image_name", ["0", "1", "a", "*"])
def test_set_scale(image_name):
"""Test `iohub.ngff.Position.set_scale()`"""
channel_names, shape, dtype = ch_shape_dtype
transform = [
TransformationMeta(type="translation", translation=(1, 2, 3, 4, 5)),
TransformationMeta(type="scale", scale=(5, 4, 3, 2, 1)),
]
translation = [float(t) for t in range(1, 6)]
scale = [float(s) for s in range(5, 0, -1)]
array_name = "0" if image_name == "*" else image_name
new_scale = 10.0
with TemporaryDirectory() as temp_dir:
store_path = os.path.join(temp_dir, "ome.zarr")
with open_ome_zarr(
store_path, layout="fov", mode="w-", channel_names=channel_names
store_path, layout="fov", mode="w-", channel_names=["a", "b"]
) as dataset:
dataset.create_zeros(name="0", shape=shape, dtype=dtype)
dataset.set_transform(image="0", transform=transform)
dataset.set_scale(image="0", axis_name="z", new_scale=10.0)
assert dataset.scale[-3] == 10.0
assert (
dataset.metadata.multiscales[0]
.datasets[0]
.coordinate_transformations[0]
.translation[-1]
== 5
dataset.create_zeros(
name=array_name,
shape=(1, 2, 4, 8, 16),
dtype=int,
transform=[
TransformationMeta(
type="translation", translation=translation
),
TransformationMeta(type="scale", scale=scale),
],
)

with pytest.raises(ValueError):
dataset.set_scale(image="0", axis_name="z", new_scale=-1.0)

assert dataset.zattrs["iohub"]["prior_z_scale"] == 3.0
dataset.set_scale(
image=image_name, axis_name="z", new_scale=-1.0
)
with pytest.raises(KeyError):
dataset.set_scale(
image="nonexistent", axis_name="z", new_scale=9.0
)
assert dataset.scale[-3] == 3.0
dataset.set_scale(
image=image_name, axis_name="z", new_scale=new_scale
)
if image_name == "*":
assert dataset.scale[-3] == new_scale * 3.0
else:
assert dataset.scale[-3] == new_scale
assert dataset.get_effective_translation(array_name) == translation
for tf in dataset.zattrs["iohub"]["previous_transforms"][0][
"transforms"
]:
if tf["type"] == "scale":
assert tf["scale"] == scale


@given(channel_names=channel_names_st)
Expand Down