1111
1212from __future__ import annotations
1313
14+ import itertools
1415from copy import deepcopy
15- from typing import Any , Callable , Sequence
16+ from typing import Any , Iterable
17+
18+ from monai .utils .enums import TraceKeys
1619
1720_TRACK_META = True
1821
@@ -72,85 +75,88 @@ class MetaObj:
7275 """
7376
7477 def __init__ (self ):
75- self ._meta : dict = self .get_default_meta ()
78+ self ._meta : dict = MetaObj .get_default_meta ()
79+ self ._applied_operations : list = MetaObj .get_default_applied_operations ()
7680 self ._is_batch : bool = False
7781
7882 @staticmethod
79- def flatten_meta_objs (args : Sequence [ Any ]) -> list [ MetaObj ] :
83+ def flatten_meta_objs (* args : Iterable ) :
8084 """
81- Recursively flatten input and return all instances of `MetaObj` as a single
82- list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
85+ Recursively flatten input and yield all instances of `MetaObj`.
86+ This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and
8387 their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type
8488 `MetaObj`.
8589
8690 Args:
87- args: Sequence of inputs to be flattened.
91+ args: Iterables of inputs to be flattened.
8892 Returns:
8993 list of nested `MetaObj` from input.
9094 """
91- out = []
92- for a in args :
95+ for a in itertools .chain (* args ):
9396 if isinstance (a , (list , tuple )):
94- out += MetaObj .flatten_meta_objs (a )
97+ yield from MetaObj .flatten_meta_objs (a )
9598 elif isinstance (a , MetaObj ):
96- out .append (a )
97- return out
99+ yield a
98100
99- def _copy_attr (self , attribute : str , input_objs : list [ MetaObj ], default_fn : Callable , deep_copy : bool ) -> None :
101+ def _copy_attr (self , attributes : list [ str ] , input_objs , defaults : list , deep_copy : bool ) -> None :
100102 """
101- Copy an attribute from the first in a list of `MetaObj`. In the case of
103+ Copy attributes from the first in a list of `MetaObj`. In the case of
102104 `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so
103105 check them all. Copy the first to `self`.
104106
105107 We also perform a deep copy of the data if desired.
106108
107109 Args:
108- attribute: string corresponding to attribute to be copied (e.g., `meta`).
109- input_objs: List of `MetaObj`. We'll copy the attribute from the first one
110+ attributes: a sequence of strings corresponding to attributes to be copied (e.g., `[' meta'] `).
111+ input_objs: an iterable of `MetaObj` instances . We'll copy the attribute from the first one
110112 that contains that particular attribute.
111- default_fn: If none of `input_objs` have the attribute that we're
112- interested in, then use this default function (e.g., `lambda: {}`.)
113- deep_copy: Should the attribute be deep copied? See `_copy_meta`.
113+ defaults: If none of `input_objs` have the attribute that we're
114+ interested in, then use this default value/function (e.g., `lambda: {}`.)
115+ the defaults must be the same length as `attributes`.
116+ deep_copy: whether to deep copy the corresponding attribute.
114117
115118 Returns:
116119 Returns `None`, but `self` should be updated to have the copied attribute.
117120 """
118- attributes = [getattr (i , attribute ) for i in input_objs if hasattr (i , attribute )]
119- if len (attributes ) > 0 :
120- val = attributes [0 ]
121- if deep_copy :
122- val = deepcopy (val )
123- setattr (self , attribute , val )
124- else :
125- setattr (self , attribute , default_fn ())
126-
127- def _copy_meta (self , input_objs : list [MetaObj ]) -> None :
121+ found = [False ] * len (attributes )
122+ for i , (idx , a ) in itertools .product (input_objs , enumerate (attributes )):
123+ if not found [idx ] and hasattr (i , a ):
124+ setattr (self , a , deepcopy (getattr (i , a )) if deep_copy else getattr (i , a ))
125+ found [idx ] = True
126+ if all (found ):
127+ return
128+ for a , f , d in zip (attributes , found , defaults ):
129+ if not f :
130+ setattr (self , a , d () if callable (defaults ) else d )
131+ return
132+
133+ def _copy_meta (self , input_objs , deep_copy = False ) -> None :
128134 """
129- Copy metadata from a list of `MetaObj`. For a given attribute, we copy the
135+ Copy metadata from an iterable of `MetaObj` instances . For a given attribute, we copy the
130136 adjunct data from the first element in the list containing that attribute.
131137
132- If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g.,
133- `a+=1`), then don't.
134-
135138 Args:
136139 input_objs: list of `MetaObj` to copy data from.
137140
138141 """
139- id_in = id (input_objs [0 ]) if len (input_objs ) > 0 else None
140- deep_copy = id (self ) != id_in
141- self ._copy_attr ("meta" , input_objs , self .get_default_meta , deep_copy )
142- self ._copy_attr ("applied_operations" , input_objs , self .get_default_applied_operations , deep_copy )
143- self .is_batch = input_objs [0 ].is_batch if len (input_objs ) > 0 else False
142+ self ._copy_attr (
143+ ["meta" , "applied_operations" ],
144+ input_objs ,
145+ [MetaObj .get_default_meta (), MetaObj .get_default_applied_operations ()],
146+ deep_copy ,
147+ )
144148
145- def get_default_meta (self ) -> dict :
149+ @staticmethod
150+ def get_default_meta () -> dict :
146151 """Get the default meta.
147152
148153 Returns:
149154 default metadata.
150155 """
151156 return {}
152157
153- def get_default_applied_operations (self ) -> list :
158+ @staticmethod
159+ def get_default_applied_operations () -> list :
154160 """Get the default applied operations.
155161
156162 Returns:
@@ -180,21 +186,29 @@ def __repr__(self) -> str:
180186 @property
181187 def meta (self ) -> dict :
182188 """Get the meta."""
183- return self ._meta
189+ return self ._meta if hasattr ( self , "_meta" ) else MetaObj . get_default_meta ()
184190
185191 @meta .setter
186- def meta (self , d : dict ) -> None :
192+ def meta (self , d ) -> None :
187193 """Set the meta."""
194+ if d == TraceKeys .NONE :
195+ self ._meta = MetaObj .get_default_meta ()
188196 self ._meta = d
189197
190198 @property
191199 def applied_operations (self ) -> list :
192200 """Get the applied operations."""
193- return self ._applied_operations
201+ if hasattr (self , "_applied_operations" ):
202+ return self ._applied_operations
203+ return MetaObj .get_default_applied_operations ()
194204
195205 @applied_operations .setter
196- def applied_operations (self , t : list ) -> None :
206+ def applied_operations (self , t ) -> None :
197207 """Set the applied operations."""
208+ if t == TraceKeys .NONE :
209+ # received no operations when decollating a batch
210+ self ._applied_operations = MetaObj .get_default_applied_operations ()
211+ return
198212 self ._applied_operations = t
199213
200214 def push_applied_operation (self , t : Any ) -> None :
@@ -206,7 +220,7 @@ def pop_applied_operation(self) -> Any:
206220 @property
207221 def is_batch (self ) -> bool :
208222 """Return whether object is part of batch or not."""
209- return self ._is_batch
223+ return self ._is_batch if hasattr ( self , "_is_batch" ) else False
210224
211225 @is_batch .setter
212226 def is_batch (self , val : bool ) -> None :
0 commit comments