Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions iohub/ngff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, Union
from typing import (
Any,
Callable,
Literal,
Protocol,
Sequence,
Union,
runtime_checkable,
)
from warnings import warn

import click
Expand Down Expand Up @@ -147,9 +155,48 @@ def create_empty_plate(
if channel_name not in metadata_channel_names:
position.append_channel(channel_name, resize_arrays=True)

@runtime_checkable
class TransformFunction(Protocol):
def __call__(
self,
data: NDArray,
input_time_index: int | None = None,
**kwargs: Any
) -> NDArray:
"""Transform image data.

The function must take the CZYX or TCZYX NDArray as the first argument and return a CZYX or TCZYX NDArray.
Additional arguments are passed through **kwargs.

Parameters
----------
data : NDArray
CZYX or TCZYX image data to transform
input_time_index : int, optional
Time index for time-dependent transformations. Defaults to None.
**kwargs
Additional transformation parameters

Returns
-------
NDArray
Transformed data with same shape as input.

Examples
Comment on lines +179 to +185
Copy link
Copy Markdown
Contributor

@ziw-liu ziw-liu Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I understand this. Are these examples of functions that are compatible with the protocols?

Copy link
Copy Markdown
Collaborator Author

@srivarra srivarra Sep 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, exactly.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But these functions will throw an error when input_time_index is given as an argument.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see that now, what about something like this?

def gaussian_blur(data: NDArray, input_time_index: None | int, **kwargs) -> NDArray:
	from skimage.filters import gaussian
	if input_time_index:
		data = data[input_time_index,...]
	return gaussian(data, **kwargs)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can test this by running it as a concrete parametrized test case.

--------
Double the Array
>>> def double(data: NDArray, **kwargs) -> NDArray:
... return data * 2
Apply Gaussian blur on the Array
>>> def gaussian_blur(data: NDArray, **kwargs) -> NDArray:
... from skimage.filters import gaussian
... return gaussian(data, **kwargs)
"""
...


def _apply_transform_to_czyx(
func: Callable[[NDArray, Any], NDArray],
func: TransformFunction,
input_position_path: Path,
input_channel_indices: Union[list[int], slice],
input_time_index: int,
Expand Down Expand Up @@ -208,7 +255,7 @@ def _save_transformed(


def apply_transform_to_czyx_and_save(
func: Callable[[NDArray, Any], NDArray],
func: TransformFunction,
input_position_path: Path,
output_position_path: Path,
input_channel_indices: Union[list[int], slice],
Expand All @@ -225,7 +272,7 @@ def apply_transform_to_czyx_and_save(

Parameters
----------
func : Callable[[NDArray, Any], NDArray]
func : TransformFunction
The function to be applied to the data.
func must take the CZYX NDArray as the first argument and return
a CZXY NDArray. Additional arguments are passed through **kwargs.
Expand Down Expand Up @@ -282,24 +329,20 @@ def apply_transform_to_czyx_and_save(
input_time_index=input_time_index,
**kwargs,
)
skipped: bool = True
if transformed is not None:
_save_transformed(
transformed,
output_time_indices=output_time_index,
output_channel_indices=output_channel_indices,
output_position_path=output_position_path,
)
_echo_finished(
time_index=input_time_index,
channel_index=input_channel_indices,
skipped=False,
)
else:
_echo_finished(
time_index=input_time_index,
channel_index=input_channel_indices,
skipped=True,
)
skipped = False
_echo_finished(
time_index=input_time_index,
channel_index=input_channel_indices,
skipped=skipped,
)


def _indices_to_shard_aligned_batches(
Expand Down Expand Up @@ -367,7 +410,7 @@ def _slice_to_list(indices: list[int] | slice) -> list[int]:


def apply_transform_to_tczyx_and_save(
func: Callable[[NDArray, Any], NDArray],
func: TransformFunction,
input_position_path: Path,
output_position_path: Path,
input_channel_indices: list[int] | slice,
Expand All @@ -382,7 +425,7 @@ def apply_transform_to_tczyx_and_save(

Parameters
----------
func : Callable[[NDArray, Any], NDArray]
func : TransformFunction
The function to be applied to the data.
func must take the TCZYX NDArray as the first argument and return
a TCZXY NDArray. Additional arguments are passed through **kwargs.
Expand Down Expand Up @@ -442,7 +485,7 @@ def apply_transform_to_tczyx_and_save(


def process_single_position(
func: Callable[[NDArray, Any], NDArray],
func: TransformFunction,
input_position_path: Path,
output_position_path: Path,
input_channel_indices: list[slice] | list[list[int]] | None = None,
Expand All @@ -459,7 +502,7 @@ def process_single_position(

Parameters
----------
func : Callable[[NDArray, Any], NDArray]
func : TransformFunction
The function to be applied to the data.
func must take the CZYX NDArray as the first argument and return
a CZXY NDArray. Additional arguments are passed through **kwargs.
Expand Down
Loading