File tree Expand file tree Collapse file tree
src/fairchem/core/units/mlip_unit Expand file tree Collapse file tree Original file line number Diff line number Diff line change 2727 from fairchem .core .units .mlip_unit import MLIPPredictUnit
2828
2929
30+ def _to_cpu (obj : Any ) -> Any :
31+ """
32+ Return a CPU-resident copy of ``obj`` so that the result can be
33+ deserialized on CPU-only Ray workers.
34+
35+ Uses ``torch.save`` + ``torch.load(map_location="cpu")`` which
36+ transparently handles arbitrary object graphs containing tensors,
37+ ``nn.Module`` instances, OmegaConf containers, etc., without needing
38+ to walk and mutate the structure ourselves.
39+ """
40+ import io
41+
42+ buf = io .BytesIO ()
43+ torch .save (obj , buf )
44+ buf .seek (0 )
45+ return torch .load (buf , map_location = "cpu" , weights_only = False )
46+
47+
3048class BatchPredictServerMixin :
3149 """
3250 Shared batched-inference logic mixed into Ray Serve deployment classes.
@@ -427,7 +445,11 @@ async def get_predict_unit_attribute(
427445 """
428446 model_id = model_id or serve .get_multiplexed_model_id ()
429447 predict_unit = await self .get_model (model_id )
430- return getattr (predict_unit , attribute_name )
448+ attr = getattr (predict_unit , attribute_name )
449+ # Move any CUDA tensors to CPU before returning so callers (which
450+ # may be CPU-only Ray workers) can deserialize the result without
451+ # requiring CUDA.
452+ return _to_cpu (attr )
431453
432454 async def validate_atoms_data (self , atoms_info : dict , task_name : str ) -> dict :
433455 """
You can’t perform that action at this time.
0 commit comments