diff --git a/iohub/ngff/utils.py b/iohub/ngff/utils.py index 11621fc5..e1f50d85 100644 --- a/iohub/ngff/utils.py +++ b/iohub/ngff/utils.py @@ -4,7 +4,15 @@ from collections import defaultdict from functools import partial from pathlib import Path -from typing import Any, Callable, Literal, Sequence, Union +from typing import ( + Any, + Callable, + Literal, + Protocol, + Sequence, + Union, + runtime_checkable, +) from warnings import warn import click @@ -147,9 +155,48 @@ def create_empty_plate( if channel_name not in metadata_channel_names: position.append_channel(channel_name, resize_arrays=True) +@runtime_checkable +class TransformFunction(Protocol): + def __call__( + self, + data: NDArray, + input_time_index: int | None = None, + **kwargs: Any + ) -> NDArray: + """Transform image data. + + The function must take the CZYX or TCZYX NDArray as the first argument and return a CZYX or TCZYX NDArray. + Additional arguments are passed through **kwargs. + + Parameters + ---------- + data : NDArray + CZYX or TCZYX image data to transform + input_time_index : int, optional + Time index for time-dependent transformations. Defaults to None. + **kwargs + Additional transformation parameters + + Returns + ------- + NDArray + Transformed data with same shape as input. + + Examples + -------- + Double the Array + >>> def double(data: NDArray, **kwargs) -> NDArray: + ... return data * 2 + Apply Gaussian blur on the Array + >>> def gaussian_blur(data: NDArray, **kwargs) -> NDArray: + ... from skimage.filters import gaussian + ... return gaussian(data, **kwargs) + """ + ... + def _apply_transform_to_czyx( - func: Callable[[NDArray, Any], NDArray], + func: TransformFunction, input_position_path: Path, input_channel_indices: Union[list[int], slice], input_time_index: int, @@ -208,7 +255,7 @@ def _save_transformed( def apply_transform_to_czyx_and_save( - func: Callable[[NDArray, Any], NDArray], + func: TransformFunction, input_position_path: Path, output_position_path: Path, input_channel_indices: Union[list[int], slice], @@ -225,7 +272,7 @@ def apply_transform_to_czyx_and_save( Parameters ---------- - func : Callable[[NDArray, Any], NDArray] + func : TransformFunction The function to be applied to the data. func must take the CZYX NDArray as the first argument and return a CZXY NDArray. Additional arguments are passed through **kwargs. @@ -282,6 +329,7 @@ def apply_transform_to_czyx_and_save( input_time_index=input_time_index, **kwargs, ) + skipped: bool = True if transformed is not None: _save_transformed( transformed, @@ -289,17 +337,12 @@ def apply_transform_to_czyx_and_save( output_channel_indices=output_channel_indices, output_position_path=output_position_path, ) - _echo_finished( - time_index=input_time_index, - channel_index=input_channel_indices, - skipped=False, - ) - else: - _echo_finished( - time_index=input_time_index, - channel_index=input_channel_indices, - skipped=True, - ) + skipped = False + _echo_finished( + time_index=input_time_index, + channel_index=input_channel_indices, + skipped=skipped, + ) def _indices_to_shard_aligned_batches( @@ -367,7 +410,7 @@ def _slice_to_list(indices: list[int] | slice) -> list[int]: def apply_transform_to_tczyx_and_save( - func: Callable[[NDArray, Any], NDArray], + func: TransformFunction, input_position_path: Path, output_position_path: Path, input_channel_indices: list[int] | slice, @@ -382,7 +425,7 @@ def apply_transform_to_tczyx_and_save( Parameters ---------- - func : Callable[[NDArray, Any], NDArray] + func : TransformFunction The function to be applied to the data. func must take the TCZYX NDArray as the first argument and return a TCZXY NDArray. Additional arguments are passed through **kwargs. @@ -442,7 +485,7 @@ def apply_transform_to_tczyx_and_save( def process_single_position( - func: Callable[[NDArray, Any], NDArray], + func: TransformFunction, input_position_path: Path, output_position_path: Path, input_channel_indices: list[slice] | list[list[int]] | None = None, @@ -459,7 +502,7 @@ def process_single_position( Parameters ---------- - func : Callable[[NDArray, Any], NDArray] + func : TransformFunction The function to be applied to the data. func must take the CZYX NDArray as the first argument and return a CZXY NDArray. Additional arguments are passed through **kwargs.