Skip to content

Commit ae8c6dd

Browse files
committed
small fix
1 parent 6090c1d commit ae8c6dd

1 file changed

Lines changed: 23 additions & 1 deletion

File tree

src/fairchem/core/units/mlip_unit/batch_server.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,24 @@
2727
from fairchem.core.units.mlip_unit import MLIPPredictUnit
2828

2929

30+
def _to_cpu(obj: Any) -> Any:
    """
    Produce a copy of ``obj`` whose tensor storages all live on the CPU.

    The object is serialized into an in-memory buffer with ``torch.save``
    and immediately read back with ``torch.load(map_location="cpu")``.
    Delegating to the serializer means nested containers, ``nn.Module``
    instances, OmegaConf containers, and similar object graphs are handled
    without this function having to traverse or mutate them itself. The
    intent is that the returned value can be deserialized by CPU-only Ray
    workers.

    Args:
        obj: Any object that ``torch.save`` is able to serialize.

    Returns:
        A freshly deserialized copy of ``obj`` loaded with
        ``map_location="cpu"``.
    """
    import io

    stream = io.BytesIO()
    torch.save(obj, stream)
    # Rewind so the load below starts from the first serialized byte.
    stream.seek(0, io.SEEK_SET)
    # weights_only=False: the payload may hold arbitrary Python objects, and
    # the bytes were produced by this very call, so full unpickling is
    # required here.
    cpu_copy = torch.load(stream, map_location="cpu", weights_only=False)
    return cpu_copy
46+
47+
3048
class BatchPredictServerMixin:
3149
"""
3250
Shared batched-inference logic mixed into Ray Serve deployment classes.
@@ -427,7 +445,11 @@ async def get_predict_unit_attribute(
427445
"""
428446
model_id = model_id or serve.get_multiplexed_model_id()
429447
predict_unit = await self.get_model(model_id)
430-
return getattr(predict_unit, attribute_name)
448+
attr = getattr(predict_unit, attribute_name)
449+
# Move any CUDA tensors to CPU before returning so callers (which
450+
# may be CPU-only Ray workers) can deserialize the result without
451+
# requiring CUDA.
452+
return _to_cpu(attr)
431453

432454
async def validate_atoms_data(self, atoms_info: dict, task_name: str) -> dict:
433455
"""

0 commit comments

Comments
 (0)