1515from abc import ABC , abstractmethod
1616from collections .abc import Hashable , Mapping
1717from copy import deepcopy
18- from typing import Any
18+ from typing import Any , cast
1919
2020import numpy as np
2121import torch
@@ -105,7 +105,7 @@ def update_ops_nested_label(self, nested_key: str, op: Operations) -> None:
105105 raise ValueError ("Nested_key input format is wrong. Please ensure it is like key1#0#key2" )
106106 root : str
107107 child_key : str
108- ( root , _ , child_key ) = keys
108+ root , _ , child_key = keys
109109 if root not in self .ops :
110110 self .ops [root ] = [{}]
111111 self .ops [root ][0 ].update ({child_key : None })
@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216 super ().__init__ (stats_name , report_format )
217217 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
218218
219+ @torch .no_grad ()
219220 def __call__ (self , data ):
220- # Input Validation Addition
221- if not isinstance (data , dict ):
222- raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
223- if self .image_key not in data :
224- raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
225- image = data [self .image_key ]
226- if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
227- raise TypeError (
228- f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
229- f"but got { type (image ).__name__ } ."
230- )
231- if image .ndim < 3 :
232- raise ValueError (
233- f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
234- )
235- # --- End of validation ---
236221 """
237- Callable to execute the pre-defined functions
222+ Callable to execute the pre-defined functions.
238223
239224 Returns:
240225 A dictionary. The dict has the key in self.report_format. The value of
241226 ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227 has stats pre-defined by SampleOperations (max, min, ....).
243228
244229 Raises:
245- RuntimeError if the stats report generated is not consistent with the pre-
230+ KeyError: if ``self.image_key`` is not present in the input data.
231+ TypeError: if the input data is not a dictionary, or if the image value is
232+ not a numpy array, torch.Tensor, or MetaTensor.
233+ ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+ ``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+ RuntimeError: if the stats report generated is not consistent with the pre-
246236 defined report_format.
247237
248238 Note:
249239 The stats operation uses numpy and torch to compute max, min, and other
250240 functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242 """
243+ if not isinstance (data , dict ):
244+ raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
245+ if self .image_key not in data :
246+ raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
247+ image = data [self .image_key ]
248+ if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
249+ raise TypeError (
250+ f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
251+ f"but got { type (image ).__name__ } ."
252+ )
253+ if image .ndim < 3 :
254+ raise ValueError (
255+ f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
256+ )
257+
253258 d = dict (data )
254259 start = time .time ()
255- restore_grad_state = torch .is_grad_enabled ()
256- torch .set_grad_enabled (False )
257-
258260 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
259- if "nda_croppeds" not in d :
261+ if "nda_croppeds" in d :
262+ nda_croppeds = d ["nda_croppeds" ]
263+ if not isinstance (nda_croppeds , (list , tuple )) or len (nda_croppeds ) != len (ndas ):
264+ raise ValueError (
265+ "Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+ f"(expected { len (ndas )} )."
267+ )
268+ else :
260269 nda_croppeds = [get_foreground_image (nda ) for nda in ndas ]
261270
262- # perform calculation
263271 report = deepcopy (self .get_report_format ())
264272
265273 report [ImageStatsKeys .SHAPE ] = [list (nda .shape ) for nda in ndas ]
@@ -284,7 +292,6 @@ def __call__(self, data):
284292
285293 d [self .stats_name ] = report
286294
287- torch .set_grad_enabled (restore_grad_state )
288295 logger .debug (f"Get image stats spent { time .time () - start } " )
289296 return d
290297
@@ -321,6 +328,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321328 super ().__init__ (stats_name , report_format )
322329 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
323330
331+ @torch .no_grad ()
324332 def __call__ (self , data : Mapping ) -> dict :
325333 """
326334 Callable to execute the pre-defined functions
@@ -341,9 +349,6 @@ def __call__(self, data: Mapping) -> dict:
341349
342350 d = dict (data )
343351 start = time .time ()
344- restore_grad_state = torch .is_grad_enabled ()
345- torch .set_grad_enabled (False )
346-
347352 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
348353 ndas_label = d [self .label_key ] # (H,W,D)
349354
@@ -353,7 +358,6 @@ def __call__(self, data: Mapping) -> dict:
353358 nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
354359 nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
355360
356- # perform calculation
357361 report = deepcopy (self .get_report_format ())
358362
359363 report [ImageStatsKeys .INTENSITY ] = [
@@ -365,7 +369,6 @@ def __call__(self, data: Mapping) -> dict:
365369
366370 d [self .stats_name ] = report
367371
368- torch .set_grad_enabled (restore_grad_state )
369372 logger .debug (f"Get foreground image stats spent { time .time () - start } " )
370373 return d
371374
@@ -418,6 +421,7 @@ def __init__(
418421 id_seq = ID_SEP_KEY .join ([LabelStatsKeys .LABEL , "0" , LabelStatsKeys .IMAGE_INTST ])
419422 self .update_ops_nested_label (id_seq , SampleOperations ())
420423
424+ @torch .no_grad ()
421425 def __call__ (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor | dict ]:
422426 """
423427 Callable to execute the pre-defined functions.
@@ -468,21 +472,31 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
468472 """
469473 d : dict [Hashable , MetaTensor ] = dict (data )
470474 start = time .time ()
471- if isinstance (d [self .image_key ], (torch .Tensor , MetaTensor )) and d [self .image_key ].device .type == "cuda" :
472- using_cuda = True
473- else :
474- using_cuda = False
475- restore_grad_state = torch .is_grad_enabled ()
476- torch .set_grad_enabled (False )
475+ image_tensor = d [self .image_key ]
476+ label_tensor = d [self .label_key ]
477+ using_cuda = any (
478+ isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda" for t in (image_tensor , label_tensor )
479+ )
477480
478- ndas : list [MetaTensor ] = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])] # type: ignore
479- ndas_label : MetaTensor = d [self .label_key ].astype (torch .int16 ) # (H,W,D)
481+ if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
482+ label_tensor , (MetaTensor , torch .Tensor )
483+ ):
484+ if label_tensor .device != image_tensor .device :
485+ if using_cuda :
486+ cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
487+ image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
488+ label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
489+ else :
490+ label_tensor = cast (MetaTensor , label_tensor .to (image_tensor .device ))
491+
492+ ndas : list [MetaTensor ] = [image_tensor [i ] for i in range (image_tensor .shape [0 ])] # type: ignore
493+ ndas_label : MetaTensor = label_tensor .astype (torch .int16 ) # (H,W,D)
480494
481495 if ndas_label .shape != ndas [0 ].shape :
482496 raise ValueError (f"Label shape { ndas_label .shape } is different from image shape { ndas [0 ].shape } " )
483497
484498 nda_foregrounds : list [torch .Tensor ] = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
485- nda_foregrounds = [nda if nda .numel () > 0 else torch . Tensor ([ 0 ]) for nda in nda_foregrounds ]
499+ nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([ 0. 0 ]) for nda in nda_foregrounds ]
486500
487501 unique_label = unique (ndas_label )
488502 if isinstance (ndas_label , (MetaTensor , torch .Tensor )):
@@ -534,7 +548,6 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
534548
535549 d [self .stats_name ] = report # type: ignore[assignment]
536550
537- torch .set_grad_enabled (restore_grad_state )
538551 logger .debug (f"Get label stats spent { time .time () - start } " )
539552 return d # type: ignore[return-value]
540553
0 commit comments