Commit 098a443

emanuilo and ericspod authored
Enable global coordinates in spatial crop transforms (#8794)
I ran into this while building an injury classification pipeline on MRI. My annotations come in physical coordinates, and I kept having to manually convert them to voxel space whenever I changed the preprocessing spacing. I found this issue, saw the design was already agreed on, and went ahead and implemented it. Follows the approach proposed by @ericspod in #8206.

Resolves #8206

## Changes

- Adds `TransformPointsWorldToImaged` and `TransformPointsImageToWorldd` — thin subclasses of `ApplyTransformToPointsd` that hardcode the `invert_affine` direction
- `SpatialCropd` now accepts string dictionary keys for `roi_center`, `roi_size`, `roi_start`, and `roi_end`. When a string is passed, the actual values are read from the data dict at call time instead of at init
- This means world-space annotations can be piped straight through coordinate conversion and into cropping without manual recalculation

```python
from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd

pipeline = Compose([
    TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
    TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
    SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
])
```

## Design notes

- When no string keys are passed, `SpatialCropd` takes the original code path — zero overhead for existing usage
- The string-key path recreates a `SpatialCrop` on each `__call__` as a per-call local instance, since slice computation must be deferred until the data dict is available. `inverse()` is overridden for this path to use `pop_transform(check=False)`, since the cropper identity changes between calls — the actual crop info is read from the MetaTensor's transform stack
- Tensors from `ApplyTransformToPoints` (shape `(C, N, dims)`) get flattened and rounded to int via `torch.round` (round half to even, i.e. banker's rounding) to avoid systematic bias

## Tests

- 13 tests for `TransformPointsWorldToImaged` (correctness, equivalence with the base class, inverse, error cases)
- 7 tests for `TransformPointsImageToWorldd`
- 11 tests for `SpatialCropd` string-key support (start/end, center/size, mixed params, tensor shapes, float rounding, missing keys, `requires_current_data`, multi-key, inverse)
- 1 end-to-end integration test: `TransformPointsWorldToImaged` → `SpatialCropd` with a world-space ROI
- All 12 existing `SpatialCropd` tests pass unchanged
- All 21 existing `ApplyTransformToPointsd` tests pass unchanged

---------

Signed-off-by: Emanuilo Jovanovic <emanuilo.jovanovic@smartcat.io>
Signed-off-by: Emanuilo Jovanovic <emanuilo.jovanovic@hotmail.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
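A quick aside on the rounding choice: round half to even matters because naive "round halves up" shifts every half-integer coordinate in the same direction, biasing crop boundaries. The sketch below (not MONAI code; `resolve_roi_values` is a hypothetical stand-in) uses Python's built-in `round()`, which follows the same round-half-to-even rule as `torch.round`, to show the flatten-and-round step applied to a `(C, N, dims)`-shaped point:

```python
# Why the patch rounds half to even ("banker's rounding") instead of always
# rounding halves up. Python's round() follows the same IEEE 754 rule as
# torch.round, so it stands in for the tensor op here.

def resolve_roi_values(points):
    """Flatten a nested (C, N, dims)-style point list and round to ints,
    mirroring what the string-key path does to values from the data dict."""
    flat = [v for channel in points for point in channel for v in point]
    return [round(v) for v in flat]

# A single world-derived coordinate arrives with shape (1, 1, 3):
print(resolve_roi_values([[[142.5, -67.3, 301.8]]]))  # [142, -67, 302]

# Half-integer coordinates round to the nearest even integer, so the
# rounding error averages to zero instead of drifting upward:
print([round(x) for x in [0.5, 1.5, 2.5, 3.5]])  # [0, 2, 2, 4]
```

With always-round-up the second list would be `[1, 2, 3, 4]`, a consistent +0.5 drift on exact halves.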
1 parent 65beb58 · commit 098a443

6 files changed: 546 additions, 11 deletions
### monai/transforms/__init__.py (6 additions, 0 deletions)

```diff
@@ -676,6 +676,12 @@
     ToTensord,
     ToTensorD,
     ToTensorDict,
+    TransformPointsImageToWorldd,
+    TransformPointsImageToWorldD,
+    TransformPointsImageToWorldDict,
+    TransformPointsWorldToImaged,
+    TransformPointsWorldToImageD,
+    TransformPointsWorldToImageDict,
     Transposed,
     TransposeD,
     TransposeDict,
```
### monai/transforms/croppad/dictionary.py (142 additions, 11 deletions)

```diff
@@ -19,7 +19,7 @@

 from collections.abc import Callable, Hashable, Mapping, Sequence
 from copy import deepcopy
-from typing import Any
+from typing import Any, Optional, Union, cast

 import numpy as np
 import torch
@@ -50,7 +50,7 @@
 from monai.transforms.traits import LazyTrait, MultiSampleTrait
 from monai.transforms.transform import LazyTransform, MapTransform, Randomizable
 from monai.transforms.utils import is_positive
-from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep
+from monai.utils import MAX_SEED, Method, PytorchPadMode, TraceKeys, ensure_tuple_rep

 __all__ = [
     "Padd",
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd):
     - a spatial center and size
     - the start and end coordinates of the ROI

+    ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as
+    string dictionary keys. When a string is provided, the actual coordinate values are read from the
+    data dictionary at call time. This enables pipelines where coordinates are computed by earlier
+    transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the
+    data dictionary under the given key.
+
+    Example::
+
+        from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd
+
+        pipeline = Compose([
+            TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"),
+            TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"),
+            SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"),
+        ])
+
     This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
     for more information.
     """

     def __init__(
         self,
         keys: KeysCollection,
-        roi_center: Sequence[int] | int | None = None,
-        roi_size: Sequence[int] | int | None = None,
-        roi_start: Sequence[int] | int | None = None,
-        roi_end: Sequence[int] | int | None = None,
+        roi_center: Sequence[int] | int | str | None = None,
+        roi_size: Sequence[int] | int | str | None = None,
+        roi_start: Sequence[int] | int | str | None = None,
+        roi_end: Sequence[int] | int | str | None = None,
         roi_slices: Sequence[slice] | None = None,
         allow_missing_keys: bool = False,
         lazy: bool = False,
@@ -450,19 +466,134 @@ def __init__(
         Args:
             keys: keys of the corresponding items to be transformed.
                 See also: :py:class:`monai.transforms.compose.MapTransform`
-            roi_center: voxel coordinates for center of the crop ROI.
+            roi_center: voxel coordinates for center of the crop ROI, or a string key to look up
+                the coordinates from the data dictionary.
             roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size,
-                will not crop that dimension of the image.
-            roi_start: voxel coordinates for start of the crop ROI.
+                will not crop that dimension of the image. Can also be a string key.
+            roi_start: voxel coordinates for start of the crop ROI, or a string key to look up
+                the coordinates from the data dictionary.
             roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image,
-                use the end coordinate of image.
+                use the end coordinate of image. Can also be a string key.
             roi_slices: list of slices for each of the spatial dimensions.
             allow_missing_keys: don't raise exception if key is missing.
             lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False.
         """
-        cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy)
+        self._roi_center = roi_center
+        self._roi_size = roi_size
+        self._roi_start = roi_start
+        self._roi_end = roi_end
+        self._roi_slices = roi_slices
+        self._has_str_roi = any(isinstance(v, str) for v in [roi_center, roi_size, roi_start, roi_end])
+
+        if not self._has_str_roi:
+            _roi_t = Optional[Union[Sequence[int], int]]
+            cropper = SpatialCrop(
+                cast(_roi_t, roi_center),
+                cast(_roi_t, roi_size),
+                cast(_roi_t, roi_start),
+                cast(_roi_t, roi_end),
+                roi_slices,
+                lazy=lazy,
+            )
+        else:
+            # Placeholder cropper for the string-key path. A new SpatialCrop is
+            # built locally in __call__ once string keys are resolved from the
+            # data dictionary.
+            cropper = SpatialCrop(roi_start=[0], roi_end=[1], lazy=lazy)
         super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy)

+    @staticmethod
+    def _resolve_roi_param(val, d):
+        """Resolve an ROI parameter from the data dictionary if it is a string key.
+
+        Args:
+            val: the ROI parameter value. If a string, it is used as a key to look up
+                the actual value from ``d``. Otherwise returned as-is.
+            d: the data dictionary.
+
+        Returns:
+            The resolved ROI parameter. Tensors and numpy arrays are flattened to 1-D
+            and rounded to int64 so they can be consumed by ``Crop.compute_slices``.
+
+        Raises:
+            KeyError: if ``val`` is a string key that does not exist in ``d``.
+        """
+        if not isinstance(val, str):
+            return val
+        if val not in d:
+            raise KeyError(f"ROI key '{val}' not found in the data dictionary.")
+        resolved = d[val]
+        # ApplyTransformToPoints outputs tensors of shape (C, N, dims).
+        # A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3).
+        # Flatten to 1-D and round to integers for compute_slices.
+        # Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates.
+        if isinstance(resolved, np.ndarray):
+            resolved = torch.from_numpy(resolved)
+        if isinstance(resolved, torch.Tensor):
+            resolved = torch.round(resolved.flatten()).to(torch.int64)
+        return resolved
+
+    @property
+    def requires_current_data(self) -> bool:
+        """Returns True if ROI values are derived from dictionary members, False if constant members."""
+        return self._has_str_roi
+
+    def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]:
+        """
+        Args:
+            data: dictionary of data items to be transformed.
+            lazy: whether to execute lazily. If ``None``, uses the instance default.
+
+        Returns:
+            Dictionary with cropped data for each key.
+        """
+        if not self.requires_current_data:
+            return super().__call__(data, lazy=lazy)
+
+        d = dict(data)
+        roi_center = self._resolve_roi_param(self._roi_center, d)
+        roi_size = self._resolve_roi_param(self._roi_size, d)
+        roi_start = self._resolve_roi_param(self._roi_start, d)
+        roi_end = self._resolve_roi_param(self._roi_end, d)
+
+        lazy_ = self.lazy if lazy is None else lazy
+        cropper = SpatialCrop(
+            roi_center=roi_center,
+            roi_size=roi_size,
+            roi_start=roi_start,
+            roi_end=roi_end,
+            roi_slices=self._roi_slices,
+            lazy=lazy_,
+        )
+        for key in self.key_iterator(d):
+            d[key] = cropper(d[key], lazy=lazy_)
+        return d
+
+    def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]:
+        """
+        Inverse of the crop transform, restoring the original spatial dimensions via padding.
+
+        For the string-key path, the cropper used in ``__call__`` is a per-invocation local
+        instance, so its ``id()`` won't match the one stored in the MetaTensor's transform stack.
+        This override bypasses the ID check and applies the inverse directly using the crop info
+        stored in the MetaTensor.
+
+        Args:
+            data: dictionary of cropped ``MetaTensor`` items.
+
+        Returns:
+            Dictionary with inverse-transformed (padded) data for each key.
+        """
+        if not self.requires_current_data:
+            return super().inverse(data)
+        d = dict(data)
+        for key in self.key_iterator(d):
+            transform = self.cropper.pop_transform(d[key], check=False)
+            cropped = transform[TraceKeys.EXTRA_INFO]["cropped"]
+            inverse_transform = BorderPad(cropped)
+            with inverse_transform.trace_transform(False):
+                d[key] = inverse_transform(d[key])  # type: ignore[assignment]
+        return d
+

 class CenterSpatialCropd(Cropd):
     """
```
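The core idea in this diff — ROI parameters that may be literal values or string keys resolved against the data dict at call time — can be sketched without MONAI. `DeferredCrop` below is a hypothetical, simplified stand-in (plain nested lists, start/end only), not the committed implementation:

```python
class DeferredCrop:
    """Minimal stand-in for SpatialCropd's string-key behaviour: ROI
    parameters may be literal values or keys into the data dict, and string
    keys are resolved only when the transform is called."""

    def __init__(self, roi_start=None, roi_end=None):
        self._roi_start = roi_start
        self._roi_end = roi_end

    @property
    def requires_current_data(self):
        # True when any ROI parameter must be read from the data dict.
        return any(isinstance(v, str) for v in (self._roi_start, self._roi_end))

    @staticmethod
    def _resolve(val, data):
        if not isinstance(val, str):
            return val
        if val not in data:
            raise KeyError(f"ROI key '{val}' not found in the data dictionary.")
        # Round float coordinates to ints, as the real transform does.
        return [int(round(x)) for x in data[val]]

    def __call__(self, data, key="image"):
        start = self._resolve(self._roi_start, data)
        end = self._resolve(self._roi_end, data)
        d = dict(data)
        # 2-D crop: rows start[0]:end[0], columns start[1]:end[1].
        d[key] = [row[start[1]:end[1]] for row in d[key][start[0]:end[0]]]
        return d

img = [[r * 10 + c for c in range(6)] for r in range(6)]
crop = DeferredCrop(roi_start="roi_start", roi_end="roi_end")
out = crop({"image": img, "roi_start": [1.0, 2.0], "roi_end": [3.0, 5.0]})
print(out["image"])  # [[12, 13, 14], [22, 23, 24]]
```

The same instance can be reused with different dicts carrying different ROI values, which is exactly why the real transform must rebuild its cropper per call instead of fixing slices at init.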

### monai/transforms/utility/dictionary.py (88 additions, 0 deletions)

```diff
@@ -192,6 +192,12 @@
     "ApplyTransformToPointsd",
     "ApplyTransformToPointsD",
     "ApplyTransformToPointsDict",
+    "TransformPointsWorldToImaged",
+    "TransformPointsWorldToImageD",
+    "TransformPointsWorldToImageDict",
+    "TransformPointsImageToWorldd",
+    "TransformPointsImageToWorldD",
+    "TransformPointsImageToWorldDict",
     "FlattenSequenced",
     "FlattenSequenceD",
     "FlattenSequenceDict",
@@ -1918,6 +1924,86 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
         return d


+class TransformPointsWorldToImaged(ApplyTransformToPointsd):
+    """
+    Dictionary-based transform to convert points from world coordinates to image coordinates.
+
+    This is a convenience subclass of :py:class:`monai.transforms.ApplyTransformToPointsd` with
+    ``invert_affine=True``, which transforms world-space coordinates into the coordinate space of a
+    reference image by inverting the image's affine matrix.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: monai.transforms.MapTransform
+        refer_keys: The key of the reference image used to derive the affine transformation.
+            This is required because the affine must come from a reference image.
+            It can also be a sequence of keys, in which case each refers to the affine applied
+            to the matching points in ``keys``.
+        dtype: The desired data type for the output.
+        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
+            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
+        allow_missing_keys: Don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        refer_keys: KeysCollection,
+        dtype: DtypeLike | torch.dtype = torch.float64,
+        affine_lps_to_ras: bool = False,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(
+            keys=keys,
+            refer_keys=refer_keys,
+            dtype=dtype,
+            affine=None,
+            invert_affine=True,
+            affine_lps_to_ras=affine_lps_to_ras,
+            allow_missing_keys=allow_missing_keys,
+        )
+
+
+class TransformPointsImageToWorldd(ApplyTransformToPointsd):
+    """
+    Dictionary-based transform to convert points from image coordinates to world coordinates.
+
+    This is a convenience subclass of :py:class:`monai.transforms.ApplyTransformToPointsd` with
+    ``invert_affine=False``, which transforms image-space coordinates into world-space coordinates
+    by applying the reference image's affine matrix directly.
+
+    Args:
+        keys: keys of the corresponding items to be transformed.
+            See also: monai.transforms.MapTransform
+        refer_keys: The key of the reference image used to derive the affine transformation.
+            This is required because the affine must come from a reference image.
+            It can also be a sequence of keys, in which case each refers to the affine applied
+            to the matching points in ``keys``.
+        dtype: The desired data type for the output.
+        affine_lps_to_ras: Defaults to ``False``. Set to ``True`` if your point data is in the RAS
+            coordinate system or you're using ``ITKReader`` with ``affine_lps_to_ras=True``.
+        allow_missing_keys: Don't raise exception if key is missing.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        refer_keys: KeysCollection,
+        dtype: DtypeLike | torch.dtype = torch.float64,
+        affine_lps_to_ras: bool = False,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(
+            keys=keys,
+            refer_keys=refer_keys,
+            dtype=dtype,
+            affine=None,
+            invert_affine=False,
+            affine_lps_to_ras=affine_lps_to_ras,
+            allow_missing_keys=allow_missing_keys,
+        )
+
+
 class FlattenSequenced(MapTransform, ReduceTrait):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.
@@ -1983,4 +2069,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
 AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
 FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
 ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
+TransformPointsWorldToImageD = TransformPointsWorldToImageDict = TransformPointsWorldToImaged
+TransformPointsImageToWorldD = TransformPointsImageToWorldDict = TransformPointsImageToWorldd
 FlattenSequenceD = FlattenSequenceDict = FlattenSequenced
```
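The two subclasses are mirror images of one mapping: image → world applies the reference image's affine, world → image applies its inverse. For intuition only (MONAI uses the full affine matrix; this sketch assumes an axis-aligned affine with per-axis spacing and origin, no rotation or shear), the relationship reduces to:

```python
# Illustrative world <-> image mapping for an axis-aligned affine:
#   image -> world:  x_w = spacing * i + origin   (invert_affine=False)
#   world -> image:  i = (x_w - origin) / spacing (invert_affine=True)
# The helper names and the sample spacing/origin values are made up for
# this example; they are not MONAI APIs.

def image_to_world(idx, spacing, origin):
    return [s * i + o for i, s, o in zip(idx, spacing, origin)]

def world_to_image(pt, spacing, origin):
    return [(x - o) / s for x, s, o in zip(pt, spacing, origin)]

spacing, origin = (2.0, 2.0, 3.0), (-90.0, -126.0, 72.0)

world = image_to_world([10, 20, 5], spacing, origin)
print(world)      # [-70.0, -86.0, 87.0]

# Applying the inverse mapping recovers the original voxel index,
# which is what the commit's roundtrip inverse tests rely on:
print(world_to_image(world, spacing, origin))  # [10.0, 20.0, 5.0]
```

With a general affine the same roundtrip holds as a matrix multiply by `A` and `A⁻¹` in homogeneous coordinates, which is what `ApplyTransformToPointsd` performs under the hood.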
