File tree Expand file tree Collapse file tree
src/fairchem/core/units/mlip_unit Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -185,6 +185,14 @@ def _split_predictions(
185185 f"and num_atoms { batch .num_atoms } "
186186 )
187187
188+ # Move to CPU before returning so the caller (which may be a
189+ # CPU-only Ray worker) can deserialize the result without
190+ # requiring CUDA.
191+ if hasattr (system_predictions [key ], "detach" ):
192+ system_predictions [key ] = (
193+ system_predictions [key ].detach ().cpu ()
194+ )
195+
188196 split_preds .append (system_predictions )
189197
190198 return split_preds
@@ -226,6 +234,9 @@ def __init__(
226234 "BatchPredictServer initialized with predict_unit from object store"
227235 )
228236
237+ def is_multiplexed (self ) -> bool :
238+ return False
239+
229240
230241@serve .deployment (
231242 logging_config = serve .schema .LoggingConfig (log_level = "WARNING" ),
@@ -270,6 +281,9 @@ def __init__(
270281 self .configure_batching (max_batch_size , batch_wait_timeout_s )
271282 logging .info ("MultiplexedBatchPredictServer initialized" )
272283
284+ def is_multiplexed (self ) -> bool :
285+ return True
286+
273287 @serve .batch (
274288 batch_size_fn = lambda batch : sum (sample .natoms .sum () for sample in batch ).item ()
275289 )
You can’t perform that action at this time.
0 commit comments