Commit 9c00133

Merge branch 'dev' into feature/global-coords-spatial-crop
2 parents 060c4e1 + 65beb58 commit 9c00133

75 files changed

Lines changed: 3162 additions & 252 deletions


.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ jobs:
           find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
       - name: Install the dependencies
         run: |
-          python -m pip install --user --upgrade pip wheel
+          python -m pip install --user --upgrade pip wheel pybind11
           python -m pip install torch==2.5.1 torchvision==0.20.1
           cat "requirements-dev.txt"
           python -m pip install --no-build-isolation -r requirements-dev.txt

MANIFEST.in

Lines changed: 2 additions & 0 deletions
@@ -3,3 +3,5 @@ include monai/_version.py
 
 include README.md
 include LICENSE
+
+prune tests

docs/source/losses.rst

Lines changed: 5 additions & 0 deletions
@@ -98,6 +98,11 @@ Segmentation Losses
 .. autoclass:: NACLLoss
     :members:
 
+`MCCLoss`
+~~~~~~~~~
+.. autoclass:: MCCLoss
+    :members:
+
 Registration Losses
 -------------------
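
Note: this hunk documents the new `MCCLoss` (Matthews correlation coefficient loss) alongside the other segmentation losses. A minimal usage sketch, assuming the class is exported from `monai.losses` like its neighbors and follows their (prediction, target) call convention; the default constructor is an assumption:

# Hypothetical usage of the newly documented MCCLoss; the import path and
# default constructor are assumptions based on MONAI's other segmentation losses.
import torch
from monai.losses import MCCLoss

pred = torch.rand(2, 1, 32, 32, requires_grad=True)  # network output (B, C, H, W)
target = (torch.rand(2, 1, 32, 32) > 0.5).float()    # binary ground truth
loss = MCCLoss()(pred, target)                       # scalar loss
loss.backward()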

docs/source/metrics.rst

Lines changed: 7 additions & 0 deletions
@@ -158,6 +158,13 @@ Metrics
 
 `Fréchet Inception Distance`
 ------------------------------
+`Embedding Collapse`
+------------------------------
+.. autofunction:: compute_embedding_collapse
+
+.. autoclass:: EmbeddingCollapseMetric
+    :members:
+
 .. autofunction:: compute_frechet_distance
 
 .. autoclass:: FIDMetric
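
Note: the commit only adds these doc entries for the embedding-collapse metric, so the call pattern below is a guess based on MONAI's usual call-then-aggregate metric convention; both signatures are assumptions:

# Hypothetical call pattern for the new embedding-collapse metric;
# exact signatures are assumptions (only the doc entries are in this commit).
import torch
from monai.metrics import EmbeddingCollapseMetric, compute_embedding_collapse

embeddings = torch.randn(64, 128)  # (num_samples, embedding_dim)
print(compute_embedding_collapse(embeddings))

metric = EmbeddingCollapseMetric()
metric(embeddings)                 # accumulate a batch
print(metric.aggregate())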

monai/apps/auto3dseg/auto_runner.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def __init__(
             input = os.path.join(os.path.abspath(work_dir), "input.yaml")
             logger.info(f"Input config is not provided, using the default {input}")
 
-        self.data_src_cfg = dict()
+        self.data_src_cfg = {}
         if isinstance(input, dict):
             self.data_src_cfg = input
         elif isinstance(input, str) and os.path.isfile(input):

monai/apps/detection/transforms/box_ops.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def convert_box_to_mask(
         boxes_only_mask = np.ones(box_size, dtype=np.int16) * np.int16(labels_np[b])
         # apply to global mask
         slicing = [b]
-        slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims))  # type:ignore
+        slicing.extend(slice(boxes_np[b, d], boxes_np[b, d + spatial_dims]) for d in range(spatial_dims))  # type: ignore
         boxes_mask_np[tuple(slicing)] = boxes_only_mask
     return convert_to_dst_type(src=boxes_mask_np, dst=boxes, dtype=torch.int16)[0]
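
Note: the touched line builds an N-dimensional index dynamically: a batch index plus one slice per spatial axis, so the same code handles both 2D and 3D boxes. A self-contained sketch of the idiom with made-up numbers:

# Sketch of the dynamic slicing idiom above: a list [batch_index, slice, slice, ...]
# converted to a tuple selects one box region inside the global mask.
import numpy as np

boxes_mask_np = np.zeros((2, 8, 8), dtype=np.int16)  # (batch, H, W)
box = [1, 2, 5, 6]                                   # [x1, y1, x2, y2]
spatial_dims = 2

slicing = [0]  # batch index b
slicing.extend(slice(box[d], box[d + spatial_dims]) for d in range(spatial_dims))
boxes_mask_np[tuple(slicing)] = 7                    # fills boxes_mask_np[0, 1:5, 2:6]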

monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
         # compute anchor centers regarding to the image.
         # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
         shifts_centers = [
-            torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
+            torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
             for axis in range(self.spatial_dims)
         ]
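
Note: the added `+ stride[axis] // 2` term moves anchor centers from the corner of each feature-map cell to its center. A small sketch of the effect along one axis (values are illustrative):

# Effect of the half-stride offset on anchor centers along one spatial axis.
import torch

size, stride = 4, 8  # feature-map length and stride for one axis

old = torch.arange(0, size, dtype=torch.int32) * stride
new = torch.arange(0, size, dtype=torch.int32) * stride + stride // 2

print(old.tolist())  # [0, 8, 16, 24]  -> anchors on cell corners
print(new.tolist())  # [4, 12, 20, 28] -> anchors on cell centers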

monai/apps/nuclick/transforms.py

Lines changed: 2 additions & 2 deletions
@@ -367,14 +367,14 @@ def inclusion_map(self, mask, dtype):
 
     def exclusion_map(self, others, dtype, jitter_range, drop_rate):
         point_mask = torch.zeros_like(others, dtype=dtype)
-        if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
+        if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
             return point_mask
 
         max_x = point_mask.shape[0] - 1
         max_y = point_mask.shape[1] - 1
         stats = measure.regionprops(convert_to_numpy(others))
         for stat in stats:
-            if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
+            if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
                 continue
 
             # random jitter
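
Note: replacing `np.random.choice` with `self.R.choice` routes the drop decisions through the transform's own `numpy.random.RandomState` (`self.R` on MONAI's `Randomizable`), so seeding the transform reproduces them without touching NumPy's global RNG. A minimal sketch of that behavior; the demo class is illustrative, not part of the commit:

# Why self.R matters: Randomizable transforms carry their own RandomState.
from monai.transforms import Randomizable

class DropDemo(Randomizable):
    def drop(self, drop_rate: float) -> bool:
        # same call pattern as exclusion_map above
        return bool(self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]))

t = DropDemo()
t.set_random_state(seed=42)
first = [t.drop(0.5) for _ in range(5)]
t.set_random_state(seed=42)
second = [t.drop(0.5) for _ in range(5)]
assert first == second  # reproducible, independent of np.random's global state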

monai/apps/reconstruction/transforms/dictionary.py

Lines changed: 34 additions & 3 deletions
@@ -20,6 +20,7 @@
 from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask
 from monai.config import DtypeLike, KeysCollection
 from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_tensor import MetaTensor
 from monai.transforms import InvertibleTransform
 from monai.transforms.croppad.array import SpatialCrop
 from monai.transforms.intensity.array import NormalizeIntensity

@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
     Moves keys from meta to data. It is useful when a dataset of paired samples
     is loaded and certain keys should be moved from meta to data.
 
+    This transform supports two modes:
+
+    1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
+       ``image_only=False`` was used with ``LoadImaged``), the requested keys are
+       extracted directly from that dictionary.
+
+    2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
+       ``image_only=True`` was used with ``LoadImaged``), the requested keys are
+       extracted from its ``.meta`` attribute.
+
     Args:
         keys: keys to be transferred from meta to data
-        meta_key: the meta key where all the meta-data is stored
+        meta_key: the key in the data dictionary where the metadata source is
+            stored. This can be either a metadata dictionary or a ``MetaTensor``.
         allow_missing_keys: don't raise exception if key is missing
 
     Example:
         When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
         but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
         In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
+
+        When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
+        data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
+        set ``meta_key`` to the key of the ``MetaTensor`` itself::
+
+            li = LoadImaged(keys="image")  # image_only=True by default
+            dat = li({"image": "image.nii"})
+            e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
+            dat = e(dat)
+            assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
     """
 
     def __init__(self, keys: KeysCollection, meta_key: str, allow_missing_keys: bool = False) -> None:

@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
             the new data dictionary
         """
         d = dict(data)
+        meta_obj = d[self.meta_key]
+
+        # If meta_key references a MetaTensor, extract from its .meta attribute;
+        # otherwise treat it as a metadata dictionary directly.
+        if isinstance(meta_obj, MetaTensor):
+            meta_dict: dict = meta_obj.meta
+        else:
+            meta_dict = dict(meta_obj)
+
         for key in self.keys:
-            if key in d[self.meta_key]:
-                d[key] = d[self.meta_key][key]  # type: ignore
+            if key in meta_dict:
+                d[key] = meta_dict[key]  # type: ignore
             elif not self.allow_missing_keys:
                 raise KeyError(
                     f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data"
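
Note: the docstring above demonstrates the `MetaTensor` mode; a complementary sketch for the plain-dictionary mode (the stand-in data values are illustrative):

# Dictionary mode: meta_key points at a plain metadata dict, and the requested
# key is copied from it into the top-level data dictionary.
from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd

data = {
    "kspace": [[0.0]],  # stand-in for the loaded k-space array
    "kspace_meta_dict": {"reconstruction_rss": [[1.0]]},
}
xform = ExtractDataKeyFromMetaKeyd(keys="reconstruction_rss", meta_key="kspace_meta_dict")
data = xform(data)
assert data["reconstruction_rss"] == [[1.0]]  # moved from meta to data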

monai/auto3dseg/analyzer.py

Lines changed: 54 additions & 41 deletions
@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Hashable, Mapping
 from copy import deepcopy
-from typing import Any
+from typing import Any, cast
 
 import numpy as np
 import torch

@@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
             raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
         root: str
         child_key: str
-        (root, _, child_key) = keys
+        root, _, child_key = keys
         if root not in self.ops:
             self.ops[root] = [{}]
         self.ops[root][0].update({child_key: None})

@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
         super().__init__(stats_name, report_format)
         self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
 
+    @torch.no_grad()
     def __call__(self, data):
-        # Input Validation Addition
-        if not isinstance(data, dict):
-            raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
-        if self.image_key not in data:
-            raise KeyError(f"Key '{self.image_key}' not found in input data.")
-        image = data[self.image_key]
-        if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
-            raise TypeError(
-                f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
-                f"but got {type(image).__name__}."
-            )
-        if image.ndim < 3:
-            raise ValueError(
-                f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
-            )
-        # --- End of validation ---
         """
-        Callable to execute the pre-defined functions
+        Callable to execute the pre-defined functions.
 
         Returns:
             A dictionary. The dict has the key in self.report_format. The value of
             ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
             has stats pre-defined by SampleOperations (max, min, ....).
 
         Raises:
-            RuntimeError if the stats report generated is not consistent with the pre-
+            KeyError: if ``self.image_key`` is not present in the input data.
+            TypeError: if the input data is not a dictionary, or if the image value is
+                not a numpy array, torch.Tensor, or MetaTensor.
+            ValueError: if the image has fewer than 3 dimensions, or if pre-computed
+                ``nda_croppeds`` is not a list/tuple with one entry per image channel.
+            RuntimeError: if the stats report generated is not consistent with the pre-
                 defined report_format.
 
         Note:
            The stats operation uses numpy and torch to compute max, min, and other
           functions. If the input has nan/inf, the stats results will be nan/inf.
 
        """
+        if not isinstance(data, dict):
+            raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
+        if self.image_key not in data:
+            raise KeyError(f"Key '{self.image_key}' not found in input data.")
+        image = data[self.image_key]
+        if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
+            raise TypeError(
+                f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
+                f"but got {type(image).__name__}."
+            )
+        if image.ndim < 3:
+            raise ValueError(
+                f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
+            )
+
         d = dict(data)
         start = time.time()
-        restore_grad_state = torch.is_grad_enabled()
-        torch.set_grad_enabled(False)
-
         ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        if "nda_croppeds" not in d:
+        if "nda_croppeds" in d:
+            nda_croppeds = d["nda_croppeds"]
+            if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
+                raise ValueError(
+                    "Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
+                    f"(expected {len(ndas)})."
+                )
+        else:
             nda_croppeds = [get_foreground_image(nda) for nda in ndas]
 
-        # perform calculation
         report = deepcopy(self.get_report_format())
 
         report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]

@@ -284,7 +292,6 @@ def __call__(self, data):
 
         d[self.stats_name] = report
 
-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get image stats spent {time.time() - start}")
         return d

@@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
         super().__init__(stats_name, report_format)
         self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
 
+    @torch.no_grad()
     def __call__(self, data: Mapping) -> dict:
         """
         Callable to execute the pre-defined functions

@@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:
 
         d = dict(data)
         start = time.time()
-        restore_grad_state = torch.is_grad_enabled()
-        torch.set_grad_enabled(False)
-
         ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
         ndas_label = d[self.label_key]  # (H,W,D)

@@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
         nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
         nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
 
-        # perform calculation
         report = deepcopy(self.get_report_format())
 
         report[ImageStatsKeys.INTENSITY] = [

@@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:
 
         d[self.stats_name] = report
 
-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get foreground image stats spent {time.time() - start}")
         return d

@@ -418,6 +421,7 @@ def __init__(
         id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
         self.update_ops_nested_label(id_seq, SampleOperations())
 
+    @torch.no_grad()
     def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:
         """
         Callable to execute the pre-defined functions.

@@ -468,21 +472,31 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
         """
         d: dict[Hashable, MetaTensor] = dict(data)
         start = time.time()
-        if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
-            using_cuda = True
-        else:
-            using_cuda = False
-        restore_grad_state = torch.is_grad_enabled()
-        torch.set_grad_enabled(False)
+        image_tensor = d[self.image_key]
+        label_tensor = d[self.label_key]
+        using_cuda = any(
+            isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
+        )
 
-        ndas: list[MetaTensor] = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]  # type: ignore
-        ndas_label: MetaTensor = d[self.label_key].astype(torch.int16)  # (H,W,D)
+        if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
+            label_tensor, (MetaTensor, torch.Tensor)
+        ):
+            if label_tensor.device != image_tensor.device:
+                if using_cuda:
+                    cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
+                    image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
+                    label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
+                else:
+                    label_tensor = cast(MetaTensor, label_tensor.to(image_tensor.device))
+
+        ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])]  # type: ignore
+        ndas_label: MetaTensor = label_tensor.astype(torch.int16)  # (H,W,D)
 
         if ndas_label.shape != ndas[0].shape:
             raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
 
         nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]
+        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
 
         unique_label = unique(ndas_label)
         if isinstance(ndas_label, (MetaTensor, torch.Tensor)):

@@ -534,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
 
         d[self.stats_name] = report  # type: ignore[assignment]
 
-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get label stats spent {time.time() - start}")
         return d  # type: ignore[return-value]
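
Note: the recurring change across these hunks replaces the manual save/disable/restore of autograd state with the `@torch.no_grad()` decorator, which scopes the grad-mode change to the call and restores it even if the body raises. A minimal sketch:

# @torch.no_grad() disables autograd for the call body only; the previous
# grad mode is restored on return or on exception.
import torch

@torch.no_grad()
def stats_call(x: torch.Tensor) -> torch.Tensor:
    return x.max()  # computed without building a graph

x = torch.rand(3, requires_grad=True)
assert torch.is_grad_enabled()  # global grad mode before the call
y = stats_call(x)
assert not y.requires_grad      # body ran under no_grad
assert torch.is_grad_enabled()  # global grad mode untouched afterwards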
