Commit 8b5ddf5

MetaTensor non-breaking changes (#4539)
* MetaTensor non-breaking changes

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
1 parent 29a724d commit 8b5ddf5

20 files changed: 553 additions & 126 deletions

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -131,10 +131,14 @@ tests/testing_data/MedNIST*
 tests/testing_data/*Hippocampus*
 tests/testing_data/*.tiff
 tests/testing_data/schema.json
+*.svg
 
 # clang format tool
 .clang-format-bin/
 
 # VSCode
 .vscode/
 *.zip
+
+# profiling results
+*.prof

monai/data/__init__.py

Lines changed: 10 additions & 3 deletions
@@ -71,6 +71,7 @@
 from .thread_buffer import ThreadBuffer, ThreadDataLoader
 from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
 from .utils import (
+    affine_to_spacing,
     compute_importance_map,
     compute_shape_offset,
     convert_tables_to_dicts,
@@ -111,16 +112,22 @@
 from multiprocessing.reduction import ForkingPickler
 
 def _rebuild_meta(cls, storage, metadata):
-    storage_offset, size, stride, meta_obj = metadata
-    t = cls([], meta=meta_obj, dtype=storage.dtype, device=storage.device)
+    storage_offset, size, stride, meta_obj, applied_operations = metadata
+    t = cls([], meta=meta_obj, applied_operations=applied_operations, dtype=storage.dtype, device=storage.device)
     t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
     return t
 
 def reduce_meta_tensor(meta_tensor):
     storage = meta_tensor.storage()
     if storage.is_cuda:
         raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
-    metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.meta)
+    metadata = (
+        meta_tensor.storage_offset(),
+        meta_tensor.size(),
+        meta_tensor.stride(),
+        meta_tensor.meta,
+        meta_tensor.applied_operations,
+    )
     return _rebuild_meta, (type(meta_tensor), storage, metadata)
 
 ForkingPickler.register(MetaTensor, reduce_meta_tensor)
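With this change the reducer ships `applied_operations` alongside `meta`, so a transform trace survives transfer between `DataLoader` worker processes and the main process. A minimal sketch of the round trip, mimicking what the multiprocessing pickler does; the tensor, metadata key, and trace entry below are illustrative, not taken from this commit:

import io
import pickle
from multiprocessing.reduction import ForkingPickler

import torch
from monai.data import MetaTensor  # importing monai.data registers the reducer

# a CPU MetaTensor carrying both metadata and an applied-operations trace
t = MetaTensor(torch.zeros(2, 3), meta={"filename_or_obj": "img.nii.gz"})
t.applied_operations = [{"class": "Flip"}]  # illustrative trace entry

buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(t)  # goes through reduce_meta_tensor
t2 = pickle.loads(buf.getvalue())
print(t2.meta["filename_or_obj"], t2.applied_operations)  # both survive the round trip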

monai/data/dataset_summary.py

Lines changed: 6 additions & 5 deletions
@@ -18,9 +18,9 @@
 from monai.config import KeysCollection
 from monai.data.dataloader import DataLoader
 from monai.data.dataset import Dataset
+from monai.data.utils import affine_to_spacing
 from monai.transforms import concatenate
-from monai.utils import convert_data_type
-from monai.utils.enums import PostFix
+from monai.utils import PostFix, convert_data_type
 
 DEFAULT_POST_FIX = PostFix.meta()
 
@@ -84,7 +84,7 @@ def collect_meta_data(self):
                 raise ValueError(f"To collect metadata for the dataset, key `{self.meta_key}` must exist in `data`.")
             self.all_meta_data.append(data[self.meta_key])
 
-    def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
+    def get_target_spacing(self, spacing_key: str = "affine", anisotropic_threshold: int = 3, percentile: float = 10.0):
         """
         Calculate the target spacing according to all spacings.
         If the target spacing is very anisotropic,
@@ -93,7 +93,7 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
         After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.
 
         Args:
-            spacing_key: key of spacing in metadata (default: ``pixdim``).
+            spacing_key: key of the affine used to compute spacing in metadata (default: ``affine``).
             anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
             percentile: for anisotropic target spacing, use the percentile of all spacings of the anisotropic axis to
                 replace that axis.
@@ -103,7 +103,8 @@ def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold:
             self.collect_meta_data()
         if spacing_key not in self.all_meta_data[0]:
            raise ValueError("The provided spacing_key is not in self.all_meta_data.")
-        all_spacings = concatenate(to_cat=[data[spacing_key][:, 1:4] for data in self.all_meta_data], axis=0)
+        spacings = [affine_to_spacing(data[spacing_key][0], 3)[None] for data in self.all_meta_data]
+        all_spacings = concatenate(to_cat=spacings, axis=0)
         all_spacings, *_ = convert_data_type(data=all_spacings, output_type=np.ndarray, wrap_sequence=True)
 
         target_spacing = np.median(all_spacings, axis=0)
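Switching from the "pixdim" row to `affine_to_spacing` means spacing is derived as the Euclidean norm of each column of the affine's rotation/zoom block, which stays correct even when an orientation rotation is baked into the affine. A small sketch of that computation; the affine values are made up for illustration:

import numpy as np
from monai.data.utils import affine_to_spacing

# a 3D affine with an in-plane 90-degree rotation baked in
affine = np.array(
    [
        [0.0, -1.5, 0.0, 10.0],
        [2.0, 0.0, 0.0, -5.0],
        [0.0, 0.0, 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
# spacing_i = ||affine[:3, i]||, so the rotation does not distort the result
print(affine_to_spacing(affine, r=3))  # -> [2.  1.5 3. ]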

monai/data/meta_obj.py

Lines changed: 58 additions & 44 deletions
@@ -11,8 +11,11 @@
 
 from __future__ import annotations
 
+import itertools
 from copy import deepcopy
-from typing import Any, Callable, Sequence
+from typing import Any, Iterable
+
+from monai.utils.enums import TraceKeys
 
 _TRACK_META = True
 
@@ -72,85 +75,88 @@ class MetaObj:
     """
 
     def __init__(self):
-        self._meta: dict = self.get_default_meta()
+        self._meta: dict = MetaObj.get_default_meta()
+        self._applied_operations: list = MetaObj.get_default_applied_operations()
         self._is_batch: bool = False
 
     @staticmethod
-    def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
+    def flatten_meta_objs(*args: Iterable):
         """
-        Recursively flatten input and return all instances of `MetaObj` as a single
-        list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
+        Recursively flatten input and yield all instances of `MetaObj`.
+        This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
         their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type
         `MetaObj`.
 
         Args:
-            args: Sequence of inputs to be flattened.
+            args: Iterables of inputs to be flattened.
         Returns:
             list of nested `MetaObj` from input.
         """
-        out = []
-        for a in args:
+        for a in itertools.chain(*args):
             if isinstance(a, (list, tuple)):
-                out += MetaObj.flatten_meta_objs(a)
+                yield from MetaObj.flatten_meta_objs(a)
             elif isinstance(a, MetaObj):
-                out.append(a)
-        return out
+                yield a
 
-    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
+    def _copy_attr(self, attributes: list[str], input_objs, defaults: list, deep_copy: bool) -> None:
         """
-        Copy an attribute from the first in a list of `MetaObj`. In the case of
+        Copy attributes from the first in a list of `MetaObj`. In the case of
         `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so
         check them all. Copy the first to `self`.
 
         We also perform a deep copy of the data if desired.
 
         Args:
-            attribute: string corresponding to attribute to be copied (e.g., `meta`).
-            input_objs: List of `MetaObj`. We'll copy the attribute from the first one
+            attributes: a sequence of strings corresponding to attributes to be copied (e.g., `['meta']`).
+            input_objs: an iterable of `MetaObj` instances. We'll copy the attribute from the first one
                 that contains that particular attribute.
-            default_fn: If none of `input_objs` have the attribute that we're
-                interested in, then use this default function (e.g., `lambda: {}`.)
-            deep_copy: Should the attribute be deep copied? See `_copy_meta`.
+            defaults: If none of `input_objs` have the attribute that we're
+                interested in, then use this default value/function (e.g., `lambda: {}`.)
+                The defaults must be the same length as `attributes`.
+            deep_copy: whether to deep copy the corresponding attribute.
 
         Returns:
             Returns `None`, but `self` should be updated to have the copied attribute.
         """
-        attributes = [getattr(i, attribute) for i in input_objs if hasattr(i, attribute)]
-        if len(attributes) > 0:
-            val = attributes[0]
-            if deep_copy:
-                val = deepcopy(val)
-            setattr(self, attribute, val)
-        else:
-            setattr(self, attribute, default_fn())
-
-    def _copy_meta(self, input_objs: list[MetaObj]) -> None:
+        found = [False] * len(attributes)
+        for i, (idx, a) in itertools.product(input_objs, enumerate(attributes)):
+            if not found[idx] and hasattr(i, a):
+                setattr(self, a, deepcopy(getattr(i, a)) if deep_copy else getattr(i, a))
+                found[idx] = True
+        if all(found):
+            return
+        for a, f, d in zip(attributes, found, defaults):
+            if not f:
+                setattr(self, a, d() if callable(d) else d)
+        return
+
+    def _copy_meta(self, input_objs, deep_copy=False) -> None:
         """
-        Copy metadata from a list of `MetaObj`. For a given attribute, we copy the
+        Copy metadata from an iterable of `MetaObj` instances. For a given attribute, we copy the
         adjunct data from the first element in the list containing that attribute.
 
-        If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g.,
-        `a+=1`), then don't.
-
         Args:
             input_objs: list of `MetaObj` to copy data from.
 
         """
-        id_in = id(input_objs[0]) if len(input_objs) > 0 else None
-        deep_copy = id(self) != id_in
-        self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)
-        self._copy_attr("applied_operations", input_objs, self.get_default_applied_operations, deep_copy)
-        self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False
+        self._copy_attr(
+            ["meta", "applied_operations"],
+            input_objs,
+            [MetaObj.get_default_meta(), MetaObj.get_default_applied_operations()],
+            deep_copy,
+        )
 
-    def get_default_meta(self) -> dict:
+    @staticmethod
+    def get_default_meta() -> dict:
         """Get the default meta.
 
         Returns:
             default metadata.
         """
         return {}
 
-    def get_default_applied_operations(self) -> list:
+    @staticmethod
+    def get_default_applied_operations() -> list:
         """Get the default applied operations.
 
         Returns:
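Because `flatten_meta_objs` is now a generator over any number of iterables, a `__torch_function__` handler can walk `args` and `kwargs.values()` in one pass without building intermediate lists. A minimal usage sketch; the tensor values are illustrative:

import torch
from monai.data import MetaTensor
from monai.data.meta_obj import MetaObj

a = MetaTensor(torch.ones(1))
b = MetaTensor(torch.zeros(1))

# two iterables, one of them nested; the plain float is skipped
flat = list(MetaObj.flatten_meta_objs([a, [b, 1.0]], (a,)))
print(len(flat))  # -> 3, yielded in input order: a, b, a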
@@ -180,21 +186,29 @@ def __repr__(self) -> str:
     @property
     def meta(self) -> dict:
         """Get the meta."""
-        return self._meta
+        return self._meta if hasattr(self, "_meta") else MetaObj.get_default_meta()
 
     @meta.setter
-    def meta(self, d: dict) -> None:
+    def meta(self, d) -> None:
         """Set the meta."""
+        if d == TraceKeys.NONE:
+            d = MetaObj.get_default_meta()
         self._meta = d
 
     @property
     def applied_operations(self) -> list:
         """Get the applied operations."""
-        return self._applied_operations
+        if hasattr(self, "_applied_operations"):
+            return self._applied_operations
+        return MetaObj.get_default_applied_operations()
 
     @applied_operations.setter
-    def applied_operations(self, t: list) -> None:
+    def applied_operations(self, t) -> None:
         """Set the applied operations."""
+        if t == TraceKeys.NONE:
+            # received no operations when decollating a batch
+            self._applied_operations = MetaObj.get_default_applied_operations()
+            return
         self._applied_operations = t
 
     def push_applied_operation(self, t: Any) -> None:
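The net effect of the setter changes is that `TraceKeys.NONE` (the "none" placeholder that decollating a batch can produce) resets the field to its default instead of being stored verbatim, while the `hasattr` guards in the getters keep partially initialised objects usable. A small sketch of the intended behaviour; the trace entry is illustrative:

import torch
from monai.data import MetaTensor
from monai.utils.enums import TraceKeys

t = MetaTensor(torch.ones(1), applied_operations=[{"class": "Flip"}])
t.applied_operations = TraceKeys.NONE  # placeholder seen when decollating
print(t.applied_operations)  # -> [] (reset to default, not the string "none")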
@@ -206,7 +220,7 @@ def pop_applied_operation(self) -> Any:
     @property
     def is_batch(self) -> bool:
         """Return whether object is part of batch or not."""
-        return self._is_batch
+        return self._is_batch if hasattr(self, "_is_batch") else False
 
     @is_batch.setter
     def is_batch(self, val: bool) -> None:
