diff --git a/.gitignore b/.gitignore index 542e08e3b6..9c0554dca9 100644 --- a/.gitignore +++ b/.gitignore @@ -131,6 +131,7 @@ tests/testing_data/MedNIST* tests/testing_data/*Hippocampus* tests/testing_data/*.tiff tests/testing_data/schema.json +*.svg # clang format tool .clang-format-bin/ @@ -138,3 +139,6 @@ tests/testing_data/schema.json # VSCode .vscode/ *.zip + +# profiling results +*.prof diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 293f058acf..7502de5225 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -71,6 +71,7 @@ from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( + affine_to_spacing, compute_importance_map, compute_shape_offset, convert_tables_to_dicts, @@ -111,8 +112,8 @@ from multiprocessing.reduction import ForkingPickler def _rebuild_meta(cls, storage, metadata): - storage_offset, size, stride, meta_obj = metadata - t = cls([], meta=meta_obj, dtype=storage.dtype, device=storage.device) + storage_offset, size, stride, meta_obj, applied_operations = metadata + t = cls([], meta=meta_obj, applied_operations=applied_operations, dtype=storage.dtype, device=storage.device) t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride) return t @@ -120,7 +121,13 @@ def reduce_meta_tensor(meta_tensor): storage = meta_tensor.storage() if storage.is_cuda: raise NotImplementedError("sharing CUDA metatensor across processes not implemented") - metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.meta) + metadata = ( + meta_tensor.storage_offset(), + meta_tensor.size(), + meta_tensor.stride(), + meta_tensor.meta, + meta_tensor.applied_operations, + ) return _rebuild_meta, (type(meta_tensor), storage, metadata) ForkingPickler.register(MetaTensor, reduce_meta_tensor) diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py 
index 14024c0ff9..6af46d11ea 100644 --- a/monai/data/dataset_summary.py +++ b/monai/data/dataset_summary.py @@ -18,9 +18,9 @@ from monai.config import KeysCollection from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset +from monai.data.utils import affine_to_spacing from monai.transforms import concatenate -from monai.utils import convert_data_type -from monai.utils.enums import PostFix +from monai.utils import PostFix, convert_data_type DEFAULT_POST_FIX = PostFix.meta() @@ -84,7 +84,7 @@ def collect_meta_data(self): raise ValueError(f"To collect metadata for the dataset, key `{self.meta_key}` must exist in `data`.") self.all_meta_data.append(data[self.meta_key]) - def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0): + def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0): """ Calculate the target spacing according to all spacings. If the target spacing is very anisotropic, @@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`. Args: - spacing_key: key of spacing in metadata (default: ``pixdim``). + spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``). anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``). percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to replace that axis. 
@@ -103,7 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: self.collect_meta_data() if spacing_key not in self.all_meta_data[0]: raise ValueError("The provided spacing_key is not in self.all_meta_data.") - all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0) + spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data] + all_spacings = concatenate(to_cat=spacings, axis=0) all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True) target_spacing = np.median(all_spacings, axis=0) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 2e5c38938a..0f404dcac7 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -11,8 +11,11 @@ from __future__ import annotations +import itertools from copy import deepcopy -from typing import Any, Callable, Sequence +from typing import Any, Iterable + +from monai.utils.enums import TraceKeys _TRACK_META = True @@ -72,77 +75,79 @@ class MetaObj: """ def __init__(self): - self._meta: dict = self.get_default_meta() + self._meta: dict = MetaObj.get_default_meta() + self._applied_operations: list = MetaObj.get_default_applied_operations() self._is_batch: bool = False @staticmethod - def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: + def flatten_meta_objs(*args: Iterable): """ - Recursively flatten input and return all instances of `MetaObj` as a single - list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and + Recursively flatten input and yield all instances of `MetaObj`. + This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type `MetaObj`. Args: - args: Sequence of inputs to be flattened. + args: Iterables of inputs to be flattened. Returns: list of nested `MetaObj` from input. 
""" - out = [] - for a in args: + for a in itertools.chain(*args): if isinstance(a, (list, tuple)): - out += MetaObj.flatten_meta_objs(a) + yield from MetaObj.flatten_meta_objs(a) elif isinstance(a, MetaObj): - out.append(a) - return out + yield a - def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: + def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None: """ - Copy an attribute from the first in a list of `MetaObj`. In the case of + Copy attributes from the first in a list of `MetaObj`. In the case of `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so check them all. Copy the first to `self`. We also perform a deep copy of the data if desired. Args: - attribute: string corresponding to attribute to be copied (e.g., `meta`). - input_objs: List of `MetaObj`. We'll copy the attribute from the first one + attributes: a sequence of strings corresponding to attributes to be copied (e.g., `['meta']`). + input_objs: an iterable of `MetaObj` instances. We'll copy the attribute from the first one that contains that particular attribute. - default_fn: If none of `input_objs` have the attribute that we're - interested in, then use this default function (e.g., `lambda: {}`.) - deep_copy: Should the attribute be deep copied? See `_copy_meta`. + defaults: If none of `input_objs` have the attribute that we're + interested in, then use this default value/function (e.g., `lambda: {}`.) + the defaults must be the same length as `attributes`. + deep_copy: whether to deep copy the corresponding attribute. Returns: Returns `None`, but `self` should be updated to have the copied attribute. 
""" - attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)] - if len(attributes) > 0: - val = attributes[0] - if deep_copy: - val = deepcopy(val) - setattr(self, attribute, val) - else: - setattr(self, attribute, default_fn()) - - def _copy_meta(self, input_objs: list[MetaObj]) -> None: + found = [False] * len(attributes) + for i, (idx, a) in itertools.product(input_objs, enumerate(attributes)): + if not found[idx] and hasattr(i, a): + setattr(self, a, deepcopy(getattr(i, a)) if deep_copy else getattr(i, a)) + found[idx] = True + if all(found): + return + for a, f, d in zip(attributes, found, defaults): + if not f: + setattr(self, a, d() if callable(defaults) else d) + return + + def _copy_meta(self, input_objs, deep_copy=False) -> None: """ - Copy metadata from a list of `MetaObj`. For a given attribute, we copy the + Copy metadata from an iterable of `MetaObj` instances. For a given attribute, we copy the adjunct data from the first element in the list containing that attribute. - If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., - `a+=1`), then don't. - Args: input_objs: list of `MetaObj` to copy data from. """ - id_in = id(input_objs[0]) if len(input_objs) > 0 else None - deep_copy = id(self) != id_in - self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) - self._copy_attr("applied_operations", input_objs, self.get_default_applied_operations, deep_copy) - self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False + self._copy_attr( + ["meta", "applied_operations"], + input_objs, + [MetaObj.get_default_meta(), MetaObj.get_default_applied_operations()], + deep_copy, + ) - def get_default_meta(self) -> dict: + @staticmethod + def get_default_meta() -> dict: """Get the default meta. 
Returns: @@ -150,7 +155,8 @@ def get_default_meta(self) -> dict: """ return {} - def get_default_applied_operations(self) -> list: + @staticmethod + def get_default_applied_operations() -> list: """Get the default applied operations. Returns: @@ -180,21 +186,29 @@ def __repr__(self) -> str: @property def meta(self) -> dict: """Get the meta.""" - return self._meta + return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta() @meta.setter - def meta(self, d: dict) -> None: + def meta(self, d) -> None: """Set the meta.""" + if d == TraceKeys.NONE: + self._meta = MetaObj.get_default_meta() self._meta = d @property def applied_operations(self) -> list: """Get the applied operations.""" - return self._applied_operations + if hasattr(self, "_applied_operations"): + return self._applied_operations + return MetaObj.get_default_applied_operations() @applied_operations.setter - def applied_operations(self, t: list) -> None: + def applied_operations(self, t) -> None: """Set the applied operations.""" + if t == TraceKeys.NONE: + # received no operations when decollating a batch + self._applied_operations = MetaObj.get_default_applied_operations() + return self._applied_operations = t def push_applied_operation(self, t: Any) -> None: @@ -206,7 +220,7 @@ def pop_applied_operation(self) -> Any: @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" - return self._is_batch + return self._is_batch if hasattr(self, "_is_batch") else False @is_batch.setter def is_batch(self, val: bool) -> None: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9ba86300a7..a993a5e464 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,14 +13,15 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Sequence +from typing import Any, Sequence import torch from monai.config.type_definitions import NdarrayTensor from monai.data.meta_obj import MetaObj, get_track_meta -from 
monai.data.utils import decollate_batch, list_data_collate, remove_extra_metadata +from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils.enums import PostFix +from monai.utils.type_conversion import convert_to_tensor __all__ = ["MetaTensor"] @@ -125,7 +126,7 @@ def __init__( elif isinstance(x, MetaTensor): self.applied_operations = x.applied_operations else: - self.applied_operations = self.get_default_applied_operations() + self.applied_operations = MetaObj.get_default_applied_operations() # if we are creating a new MetaTensor, then deep copy attributes if isinstance(x, torch.Tensor) and not isinstance(x, MetaTensor): @@ -133,11 +134,12 @@ def __init__( self.applied_operations = deepcopy(self.applied_operations) self.affine = self.affine.to(self.device) - def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: - super()._copy_attr(attribute, input_objs, default_fn, deep_copy) - val = getattr(self, attribute) - if isinstance(val, torch.Tensor): - setattr(self, attribute, val.to(self.device)) + def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None: + super()._copy_attr(attributes, input_objs, defaults, deep_copy) + for a in attributes: + val = getattr(self, a) + if isinstance(val, torch.Tensor): + setattr(self, a, val.to(self.device)) @staticmethod def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: @@ -172,6 +174,7 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: """ out = [] metas = None + is_batch = any(x.is_batch for x in MetaObj.flatten_meta_objs(args, kwargs.values()) if hasattr(x, "is_batch")) for idx, ret in enumerate(rets): # if not `MetaTensor`, nothing to do. if not isinstance(ret, MetaTensor): @@ -181,30 +184,34 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: ret = ret.as_tensor() # else, handle the `MetaTensor` metadata. 
else: - meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) - ret._copy_meta(meta_args) + meta_args = MetaObj.flatten_meta_objs(args, kwargs.values()) # type: ignore + ret._copy_meta(meta_args, deep_copy=not is_batch) + ret.is_batch = is_batch + # the following is not implemented but the network arch may run into this case: + # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args): + # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") # If we have a batch of data, then we need to be careful if a slice of # the data is returned. Depending on how the data are indexed, we return # some or all of the metadata, and the return object may or may not be a # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). - if ret.is_batch: - # only decollate metadata once - if metas is None: - metas = decollate_batch(ret.meta) + if is_batch: # if indexing e.g., `batch[0]` if func == torch.Tensor.__getitem__: - idx = args[1] - if isinstance(idx, Sequence): - idx = idx[0] + batch_idx = args[1] + if isinstance(batch_idx, Sequence): + batch_idx = batch_idx[0] # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the # first element will be `slice(None, None, None)` and `Ellipsis`, # respectively. Don't need to do anything with the metadata. - if idx not in (slice(None, None, None), Ellipsis): - meta = metas[idx] + if batch_idx not in (slice(None, None, None), Ellipsis): + # only decollate metadata once + if metas is None: + metas = decollate_batch(ret.meta) + meta = metas[batch_idx] # if using e.g., `batch[0:2]`, then `is_batch` should still be # `True`. Also re-collate the remaining elements. - if isinstance(meta, list) and len(meta) > 1: + if isinstance(meta, list): ret.meta = list_data_collate(meta) # if using e.g., `batch[0]` or `batch[0, 1]`, then return single # element from batch, and set `is_batch` to `False`. 
@@ -222,6 +229,8 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: else: dim = 0 if dim == 0: + if metas is None: + metas = decollate_batch(ret.meta) ret.meta = metas[idx] ret.is_batch = False @@ -242,6 +251,19 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: # we might have 1 or multiple outputs. Might be MetaTensor, might be something # else (e.g., `__repr__` returns a string). # Convert to list (if necessary), process, and at end remove list if one was added. + if ( + hasattr(torch, "return_types") + and hasattr(func, "__name__") + and hasattr(torch.return_types, func.__name__) + and isinstance(getattr(torch.return_types, func.__name__), type) + and isinstance(ret, getattr(torch.return_types, func.__name__)) + ): + # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like + out_items = MetaTensor.update_meta(ret, func, args, kwargs) + for idx in range(ret.n_fields): + ret[idx].meta = out_items[idx].meta + ret[idx].applied_operations = out_items[idx].applied_operations + return ret if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence): ret = [ret] unpack = True @@ -263,6 +285,7 @@ def as_tensor(self) -> torch.Tensor: def as_dict(self, key: str) -> dict: """ Get the object as a dictionary for backwards compatibility. + This method makes a copy of the objects. Args: key: Base key to store main data. The key for the metadata will be @@ -273,7 +296,7 @@ def as_dict(self, key: str) -> dict: the metadata. 
""" return { - key: self.as_tensor(), + key: self.as_tensor().clone().detach(), PostFix.meta(key): deepcopy(self.meta), PostFix.transforms(key): deepcopy(self.applied_operations), } @@ -281,13 +304,18 @@ def as_dict(self, key: str) -> dict: @property def affine(self) -> torch.Tensor: """Get the affine.""" - return self.meta["affine"] # type: ignore + return self.meta.get("affine", self.get_default_affine()) # type: ignore @affine.setter def affine(self, d: NdarrayTensor) -> None: """Set the affine.""" self.meta["affine"] = torch.as_tensor(d, device=self.device) + @property + def pixdim(self): + """Get the spacing""" + return affine_to_spacing(self.affine) + def new_empty(self, size, dtype=None, device=None, requires_grad=False): """ must be defined for deepcopy to work @@ -313,7 +341,7 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict): By default, a `MetaTensor` is returned. However, if `get_track_meta()` is `False`, a `torch.Tensor` is returned. """ - img = torch.as_tensor(im) + img = convert_to_tensor(im) # potentially ascontiguousarray # if not tracking metadata, return `torch.Tensor` if not get_track_meta() or meta is None: @@ -321,7 +349,7 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict): # ensure affine is of type `torch.Tensor` if "affine" in meta: - meta["affine"] = torch.as_tensor(meta["affine"]) + meta["affine"] = convert_to_tensor(meta["affine"]) # remove any superfluous metadata. 
remove_extra_metadata(meta) diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index b1fe7eb327..9fb463e9b9 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -14,7 +14,14 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import ( + InterpolateMode, + convert_data_type, + deprecated, + ensure_tuple_rep, + look_up_option, + optional_import, +) Image, _ = optional_import("PIL", name="Image") @@ -74,9 +81,9 @@ def write_png( if scale is not None: data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8, copy=False) + data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8, drop_meta=True)[0] elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16, copy=False) + data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16, drop_meta=True)[0] else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") diff --git a/monai/data/utils.py b/monai/data/utils.py index 6f18e566e0..dd863c4898 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -36,6 +36,7 @@ Method, NumpyPadMode, PytorchPadMode, + TraceKeys, convert_data_type, convert_to_dst_type, ensure_tuple, @@ -391,6 +392,24 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return +def collate_meta_tensor(batch): + """collate a sequence of meta tensor sequences/dictionaries into + a single batched metatensor or a dictionary of batched metatensor""" + if not isinstance(batch, Sequence): + raise NotImplementedError() + elem_0 = first(batch) + if isinstance(elem_0, MetaObj): + collated = default_collate(batch) + collated.meta = default_collate([i.meta or TraceKeys.NONE for i in batch]) + collated.applied_operations = [i.applied_operations or 
TraceKeys.NONE for i in batch] + collated.is_batch = True + return collated + if isinstance(elem_0, Mapping): + return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} + # no more recursive search for MetaTensor + return default_collate(batch) + + def list_data_collate(batch: Sequence): """ Enhancement for PyTorch DataLoader default collate. @@ -410,15 +429,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = default_collate(data_for_batch) - if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): - ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) - ret[key].is_batch = True + ret[key] = collate_meta_tensor(data_for_batch) else: - ret = default_collate(data) - if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): - ret.meta = list_data_collate([i.meta for i in data]) - ret.is_batch = True + ret = collate_meta_tensor(data) return ret except RuntimeError as re: re_str = str(re) @@ -529,7 +542,9 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): """ if batch is None: return batch - if isinstance(batch, (float, int, str, bytes)): + if isinstance(batch, (float, int, str, bytes)) or ( + type(batch).__module__ == "numpy" and not isinstance(batch, Iterable) + ): return batch if isinstance(batch, torch.Tensor): if detach: @@ -538,11 +553,15 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) # if of type MetaObj, decollate the metadata - if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list): - metas = decollate_batch(batch.meta) - for i in range(len(out_list)): - out_list[i].meta = metas[i] # type: ignore - out_list[i].is_batch = False # type: ignore + if isinstance(batch, MetaObj): + for t, m in zip(out_list, decollate_batch(batch.meta)): + if isinstance(t, 
MetaObj): + t.meta = m + t.is_batch = False + for t, m in zip(out_list, batch.applied_operations): + if isinstance(t, MetaObj): + t.applied_operations = m + t.is_batch = False if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) @@ -643,6 +662,8 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z Returns: an `r` dimensional vector of spacing. """ + if len(affine.shape) != 2 or affine.shape[0] != affine.shape[1]: + raise ValueError(f"affine must be a square matrix, got {affine.shape}.") _affine, *_ = convert_to_dst_type(affine[:r, :r], dst=affine, dtype=dtype) if isinstance(_affine, torch.Tensor): spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) @@ -835,7 +856,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ - affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True, drop_meta=True)[0] affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c4c4bd891c..b7e13323ec 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -15,12 +15,14 @@ import torch import torch.nn.functional as F +from monai.data.meta_tensor import MetaTensor from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size from monai.transforms import Resize from monai.utils import ( BlendMode, PytorchPadMode, convert_data_type, + convert_to_dst_type, ensure_tuple, fall_back_tuple, look_up_option, @@ -172,7 +174,9 @@ def sliding_window_inference( [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] - window_data = 
torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + window_data = torch.cat( + [convert_data_type(inputs[win_slice], torch.Tensor, drop_meta=True)[0] for win_slice in unravel_slice] + ).to(sw_device) seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. @@ -272,7 +276,10 @@ def sliding_window_inference( final_output = dict(zip(dict_key, output_image_list)) else: final_output = tuple(output_image_list) # type: ignore - return final_output[0] if is_tensor_output else final_output # type: ignore + final_output = final_output[0] if is_tensor_output else final_output # type: ignore + if isinstance(inputs, MetaTensor): + final_output = convert_to_dst_type(final_output, inputs)[0] # type: ignore + return final_output def _get_scan_interval( diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index faf5093305..4722f0f040 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -17,7 +17,7 @@ from monai.transforms.croppad.array import SpatialCrop from monai.transforms.utils import generate_spatial_bounding_box -from monai.utils import MetricReduction, look_up_option, optional_import +from monai.utils import MetricReduction, convert_data_type, look_up_option, optional_import binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion") distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt") @@ -103,12 +103,7 @@ def do_metric_reduction(f: torch.Tensor, reduction: Union[MetricReduction, str] return f, not_nans -def get_mask_edges( - seg_pred: Union[np.ndarray, torch.Tensor], - seg_gt: Union[np.ndarray, torch.Tensor], - label_idx: int = 1, - crop: bool = True, -) -> Tuple[np.ndarray, np.ndarray]: +def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> Tuple[np.ndarray, np.ndarray]: """ Do binary erosion and use XOR for input to 
get the edges. This function is helpful to further calculate metrics such as Average Surface @@ -160,9 +155,8 @@ def get_mask_edges( seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim) box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - seg_pred, seg_gt = np.squeeze(cropper(seg_pred), axis=channel_dim), np.squeeze( - cropper(seg_gt), axis=channel_dim - ) + seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray, drop_meta=True)[0] + seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray, drop_meta=True)[0] # Do binary erosion and use XOR to get edges edges_pred = binary_erosion(seg_pred) ^ seg_pred diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 847614adfe..24db2a871c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -105,6 +105,7 @@ "convert_pad_mode", "convert_to_contiguous", "get_unique_labels", + "scale_affine", ] @@ -1185,16 +1186,13 @@ def map_spatial_axes( """ if spatial_axes is None: - spatial_axes_ = list(range(1, img_ndim) if channel_first else range(img_ndim - 1)) - - else: - spatial_axes_ = [] - for a in ensure_tuple(spatial_axes): - if channel_first: - spatial_axes_.append(a if a < 0 else a + 1) - else: - spatial_axes_.append(a - 1 if a < 0 else a) - + return list(range(1, img_ndim) if channel_first else range(img_ndim - 1)) + spatial_axes_ = [] + for a in ensure_tuple(spatial_axes): + if channel_first: + spatial_axes_.append(a % img_ndim if a < 0 else a + 1) + else: + spatial_axes_.append((a - 1) % (img_ndim - 1) if a < 0 else a) return spatial_axes_ @@ -1529,7 +1527,7 @@ def print_table_column(name, torch, numpy, color=Colors.none): print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red) -def convert_pad_mode(dst: NdarrayOrTensor, mode: Union[NumpyPadMode, PytorchPadMode, str]): +def 
convert_pad_mode(dst: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]]): """ Utility to convert padding mode between numpy array and PyTorch Tensor. @@ -1573,5 +1571,30 @@ def convert_to_contiguous(data, **kwargs): return data +def scale_affine(affine, spatial_size, new_spatial_size, centered: bool = True): + """ + Scale the affine matrix according to the new spatial size. + + Args: + affine: affine matrix to scale. + spatial_size: original spatial size. + new_spatial_size: new spatial size. + centered: whether the scaling is with respect to + the image center (True, default) or corner (False). + + Returns: + Scaled affine matrix. + + """ + if spatial_size == new_spatial_size: + return affine + r = len(affine) - 1 + s = np.array([float(o) / float(max(n, 1)) for o, n in zip(spatial_size, new_spatial_size)]) + scale = create_scale(r, s.tolist()) + if centered: + scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2 # type: ignore + return affine @ convert_to_dst_type(scale, affine)[0] + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2aedc77dd7..5e84efafe7 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -376,7 +376,7 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor to_long: convert input to long before performing mode. 
""" dtype = torch.int64 if to_long else None - x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype, drop_meta=True) o_t = torch.mode(x_t, dim).values o, *_ = convert_to_dst_type(o_t, x) return o @@ -389,3 +389,14 @@ def unique(x: NdarrayTensor) -> NdarrayTensor: x: array/tensor """ return torch.unique(x) if isinstance(x, torch.Tensor) else np.unique(x) # type: ignore + + +def linalg_inv(x: NdarrayTensor) -> NdarrayTensor: + """`torch.linalg.inv` with equivalent implementation for numpy. + + Args: + x: array/tensor + """ + if isinstance(x, torch.Tensor) and hasattr(torch, "inverse"): # pytorch 1.7.0 + return torch.inverse(x) # type: ignore + return torch.linalg.inv(x) if isinstance(x, torch.Tensor) else np.linalg.inv(x) # type: ignore diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 33b2a5fa2a..f53cfdaef0 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -93,6 +93,7 @@ convert_to_cupy, convert_to_dst_type, convert_to_list, + convert_to_meta_tensor, convert_to_numpy, convert_to_tensor, dtype_numpy_to_torch, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 99f645a704..fc38dc5056 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -26,7 +26,7 @@ import numpy as np import torch -from monai.config.type_definitions import NdarrayOrTensor, PathLike +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.utils.module import version_leq __all__ = [ @@ -155,7 +155,7 @@ def ensure_tuple_rep(tup: Any, dim: int) -> Tuple[Any, ...]: def fall_back_tuple( - user_provided: Any, default: Union[Sequence, np.ndarray], func: Callable = lambda x: x and x > 0 + user_provided: Any, default: Union[Sequence, NdarrayTensor], func: Callable = lambda x: x and x > 0 ) -> Tuple[Any, ...]: """ Refine `user_provided` according to the `default`, and returns as a validated tuple. 
@@ -367,6 +367,7 @@ class ImageMetaKey: FILENAME_OR_OBJ = "filename_or_obj" PATCH_INDEX = "patch_index" + SPATIAL_SHAPE = "spatial_shape" def has_option(obj, keywords: Union[str, Sequence[str]]) -> bool: diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index a6cd2522d7..2d88a269fe 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -10,11 +10,13 @@ # limitations under the License. import re +from copy import deepcopy from typing import Any, Optional, Sequence, Tuple, Type, Union import numpy as np import torch +import monai from monai.config.type_definitions import DtypeLike, NdarrayTensor from monai.utils import optional_import @@ -32,6 +34,7 @@ "convert_to_cupy", "convert_to_numpy", "convert_to_tensor", + "convert_to_meta_tensor", "convert_to_dst_type", ] @@ -70,7 +73,7 @@ def get_equivalent_dtype(dtype, data_type): """ if dtype is None: return None - if data_type is torch.Tensor: + if data_type is torch.Tensor or data_type.__name__ == "MetaTensor": if isinstance(dtype, torch.dtype): # already a torch dtype and target `data_type` is torch.Tensor return dtype @@ -111,8 +114,11 @@ def convert_to_tensor( wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. + """ if isinstance(data, torch.Tensor): + if isinstance(data, monai.data.MetaTensor): + data = data.as_tensor() return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: @@ -137,6 +143,59 @@ def convert_to_tensor( return data +def convert_to_meta_tensor( + data, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = False +): + """ + Utility to convert the input data to a MetaTensor. 
If passing a dictionary, list or tuple, + recursively check every item and convert it to MetaTensor. + + Args: + data: input data can be MetaTensor, PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. + will convert Tensor, Numpy array, float, int, bool to MetaTensor, strings and objects keep the original. + for dictionary, list or tuple, convert every item to a MetaTensor if applicable. + dtype: target data type when converting to MetaTensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. + + """ + if isinstance(data, torch.Tensor): + out = data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore + if not isinstance(out, monai.data.MetaTensor): + out = monai.data.MetaTensor(out) + return out + if isinstance(data, np.ndarray): + # skip array of string classes and object, refer to: + # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 + if re.search(r"[SaUO]", data.dtype.str) is None: + # numpy array with 0 dims is also sequence iterable, + # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims + if data.ndim > 0: + data = np.ascontiguousarray(data) + return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore + elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore + elif isinstance(data, list): + list_ret = [convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data] + return ( + monai.data.MetaTensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) # type: ignore + if wrap_sequence + else list_ret + ) + elif isinstance(data, tuple): + tuple_ret = tuple(convert_to_meta_tensor(i, dtype=dtype,
device=device) for i in data) + return ( + monai.data.MetaTensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) # type: ignore + if wrap_sequence + else tuple_ret + ) + elif isinstance(data, dict): + return {k: convert_to_meta_tensor(v, dtype=dtype, device=device) for k, v in data.items()} + + return data + + def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, @@ -212,6 +271,7 @@ def convert_data_type( device: Optional[torch.device] = None, dtype: Union[DtypeLike, torch.dtype] = None, wrap_sequence: bool = False, + drop_meta: bool = True, ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc. @@ -225,6 +285,10 @@ def convert_data_type( If left blank, it remains unchanged. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. + drop_meta: whether to drop the meta information of the input data, default to `True`. + If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor. + If `False`, converting a MetaTensor into a non-tensor instance will raise an error. 
+ Returns: modified data, orig_type, orig_device @@ -238,7 +302,9 @@ def convert_data_type( """ orig_type: type - if isinstance(data, torch.Tensor): + if isinstance(data, monai.data.MetaTensor): + orig_type = monai.data.MetaTensor + elif isinstance(data, torch.Tensor): orig_type = torch.Tensor elif isinstance(data, np.ndarray): orig_type = np.ndarray @@ -253,7 +319,19 @@ def convert_data_type( dtype_ = get_equivalent_dtype(dtype, output_type) + if not drop_meta and not issubclass(output_type, monai.data.MetaObj) and isinstance(data, monai.data.MetaObj): + # input has a MetaObj, user chose keep the metadata, but the output type cannot take a MetaObj. + if issubclass(output_type, torch.Tensor): + # user-specified MetaTensor to torch tensor keep the MetaTensor type, for backward compatibility + output_type = type(data) # type: ignore + else: + raise RuntimeError(f"the specified output_type {output_type} cannot have the metaobj, but drop_meta=False.") + data_: NdarrayTensor + + if issubclass(output_type, monai.data.MetaTensor): + data_ = convert_to_meta_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) + return data_, orig_type, orig_device if issubclass(output_type, torch.Tensor): data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence) return data_, orig_type, orig_device @@ -267,7 +345,11 @@ def convert_data_type( def convert_to_dst_type( - src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False + src: Any, + dst: NdarrayTensor, + dtype: Union[DtypeLike, torch.dtype, None] = None, + wrap_sequence: bool = False, + drop_meta: bool = True, ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert source data to the same data type and device as the destination data. @@ -281,22 +363,37 @@ def convert_to_dst_type( dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type. 
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. + drop_meta: whether to drop the meta information of the input data, default to `True`. + If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor. + If `False`, converting a MetaTensor into a non-tensor instance will raise an error. See Also: :func:`convert_data_type` """ + device = dst.device if isinstance(dst, torch.Tensor) else None if dtype is None: dtype = dst.dtype + copy_meta = False output_type: Any - if isinstance(dst, torch.Tensor): + if isinstance(dst, monai.data.MetaTensor): + output_type = monai.data.MetaTensor + if not isinstance(src, monai.data.MetaTensor): + copy_meta = True # converting a non-meta tensor to a meta tensor, probably take the metadata as well. + elif isinstance(dst, torch.Tensor): output_type = torch.Tensor elif isinstance(dst, np.ndarray): output_type = np.ndarray else: output_type = type(dst) - return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence) + output: NdarrayTensor + output, _type, _device = convert_data_type( + data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence, drop_meta=drop_meta + ) + if copy_meta and isinstance(output, monai.data.MetaTensor): # type: ignore + output.meta, output.applied_operations = deepcopy(dst.meta), deepcopy(dst.applied_operations) # type: ignore + return output, _type, _device def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: diff --git a/runtests.sh b/runtests.sh index a69c408e3c..a632e2664f 100755 --- a/runtests.sh +++ b/runtests.sh @@ -373,7 +373,6 @@ then clang_format echo "${green}done!${noColor}" - exit fi # unconditionally report on the state of monai diff --git a/tests/profile_subclass/README.md b/tests/profile_subclass/README.md new file mode 100644 index 
0000000000..de16ef2d91 --- /dev/null +++ b/tests/profile_subclass/README.md @@ -0,0 +1,43 @@ +# Profiling the performance of subclassing/`__torch_function__` in MONAI + +## Requirements +```bash +pip install py-spy +pip install snakeviz # for viewing the cProfile results +``` + +## Commands + +### Install MONAI +``` +./runtests.sh --build # from monai's root directory +``` +or follow the installation guide (https://docs.monai.io/en/latest/installation.html) + +### Profiling the task of adding two MetaTensors +```bash +python profiling.py +``` + +### Profiling using `py-spy` +```bash +py-spy record -o Tensor.svg -- python pyspy_profiling.py Tensor +py-spy record -o SubTensor.svg -- python pyspy_profiling.py SubTensor +py-spy record -o SubWithTorchFunc.svg -- python pyspy_profiling.py SubWithTorchFunc +py-spy record -o MetaTensor.svg -- python pyspy_profiling.py MetaTensor +``` + +### Profiling using `cProfile` and `SNAKEVIZ` + +```bash +python cprofile_profiling.py +snakeviz out_200.prof +``` + +--- +These tests are based on the following code: +https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark + +- Overhead for torch functions when run on `torch.Tensor` objects is on the order of 2 microseconds. +- `__torch_function__` should add zero overhead for `torch.Tensor` inputs, a small overhead for subclasses of `torch.Tensor`, and an order of microseconds for `MetaTensor`. +- Changing the dispatching mechanism may result in changes that are on the order of 100 ns, which are hard to detect due to noise, but important. diff --git a/tests/profile_subclass/cprofile_profiling.py b/tests/profile_subclass/cprofile_profiling.py new file mode 100644 index 0000000000..a6c940c9c0 --- /dev/null +++ b/tests/profile_subclass/cprofile_profiling.py @@ -0,0 +1,28 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Profiling MetaTensor +""" + +import cProfile + +import torch + +from monai.data.meta_tensor import MetaTensor + +if __name__ == "__main__": + n_chan = 3 + for hwd in (10, 200): + shape = (n_chan, hwd, hwd, hwd) + a = MetaTensor(torch.rand(shape), meta={"affine": torch.eye(4) * 2, "fname": "something1"}) + b = MetaTensor(torch.rand(shape), meta={"affine": torch.eye(4) * 3, "fname": "something2"}) + cProfile.run("c = a + b", filename=f"out_{hwd}.prof") diff --git a/tests/profile_subclass/min_classes.py b/tests/profile_subclass/min_classes.py new file mode 100644 index 0000000000..87c0ce671d --- /dev/null +++ b/tests/profile_subclass/min_classes.py @@ -0,0 +1,29 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +Minimal subclassing as baselines +Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark +""" + +import torch + +__all__ = ["SubTensor", "SubWithTorchFunc"] + + +class SubTensor(torch.Tensor): + pass + + +class SubWithTorchFunc(torch.Tensor): + def __torch_function__(self, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, {} if kwargs is None else kwargs) diff --git a/tests/profile_subclass/profiling.py b/tests/profile_subclass/profiling.py new file mode 100644 index 0000000000..28740e82e1 --- /dev/null +++ b/tests/profile_subclass/profiling.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor +Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark +""" +import argparse + +import torch +from min_classes import SubTensor, SubWithTorchFunc + +from monai.data import MetaTensor +from monai.utils.profiling import PerfContext + +NUM_REPEATS = 1000 +NUM_REPEAT_OF_REPEATS = 1000 + + +def bench(t1, t2): + bench_times = [] + for _ in range(NUM_REPEAT_OF_REPEATS): + with PerfContext() as pc: + for _ in range(NUM_REPEATS): + torch.add(t1, t2) + bench_times.append(pc.total_time) + + bench_time_min = float(torch.min(torch.Tensor(bench_times))) / NUM_REPEATS + bench_time_avg = float(torch.sum(torch.Tensor(bench_times))) / (NUM_REPEATS * NUM_REPEAT_OF_REPEATS) + bench_time_med = float(torch.median(torch.Tensor(bench_times))) / NUM_REPEATS + bench_std = float(torch.std(torch.Tensor(bench_times))) / NUM_REPEATS + return bench_time_min, bench_time_avg, bench_time_med, bench_std + + +def main(): + global NUM_REPEATS + global NUM_REPEAT_OF_REPEATS + + parser = argparse.ArgumentParser(description="Run the __torch_function__ benchmarks.") + parser.add_argument( + "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats for one measurement." 
+ ) + parser.add_argument("--nrepreps", "-m", type=int, default=NUM_REPEAT_OF_REPEATS, help="The number of measurements.") + args = parser.parse_args() + + NUM_REPEATS = args.nreps + NUM_REPEAT_OF_REPEATS = args.nrepreps + + types = torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor + + for t in types: + tensor_1 = t(1) + tensor_2 = t(2) + + b_min, b_avg, b_med, b_std = bench(tensor_1, tensor_2) + print( + "Type {} time (microseconds): min: {}, avg: {}, median: {}, and std {}.".format( + t.__name__, (10**6 * b_min), (10**6) * b_avg, (10**6) * b_med, (10**6) * b_std + ) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/profile_subclass/pyspy_profiling.py b/tests/profile_subclass/pyspy_profiling.py new file mode 100644 index 0000000000..302bfd39c3 --- /dev/null +++ b/tests/profile_subclass/pyspy_profiling.py @@ -0,0 +1,40 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor +Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark +""" +import argparse + +import torch +from min_classes import SubTensor, SubWithTorchFunc # noqa: F401 + +from monai.data import MetaTensor # noqa: F401 + +Tensor = torch.Tensor + +NUM_REPEATS = 1000000 + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the torch.add for a given class a given number of times.") + parser.add_argument("tensor_class", metavar="TensorClass", type=str, help="The class to benchmark.") + parser.add_argument("--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats.") + args = parser.parse_args() + + TensorClass = globals()[args.tensor_class] + NUM_REPEATS = args.nreps + + t1 = TensorClass(1) + t2 = TensorClass(2) + + for _ in range(NUM_REPEATS): + torch.add(t1, t2)