Skip to content

Commit 6090c1d

Browse files
zulissimeta and Copilot
committed
small fixes
Co-authored-by: Copilot <copilot@github.com>
1 parent 9e78107 commit 6090c1d

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

src/fairchem/core/units/mlip_unit/batch_server.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ def _split_predictions(
185185
f"and num_atoms {batch.num_atoms}"
186186
)
187187

188+
# Move to CPU before returning so the caller (which may be a
189+
# CPU-only Ray worker) can deserialize the result without
190+
# requiring CUDA.
191+
if hasattr(system_predictions[key], "detach"):
192+
system_predictions[key] = (
193+
system_predictions[key].detach().cpu()
194+
)
195+
188196
split_preds.append(system_predictions)
189197

190198
return split_preds
@@ -226,6 +234,9 @@ def __init__(
226234
"BatchPredictServer initialized with predict_unit from object store"
227235
)
228236

237+
def is_multiplexed(self) -> bool:
238+
return False
239+
229240

230241
@serve.deployment(
231242
logging_config=serve.schema.LoggingConfig(log_level="WARNING"),
@@ -270,6 +281,9 @@ def __init__(
270281
self.configure_batching(max_batch_size, batch_wait_timeout_s)
271282
logging.info("MultiplexedBatchPredictServer initialized")
272283

284+
def is_multiplexed(self) -> bool:
285+
return True
286+
273287
@serve.batch(
274288
batch_size_fn=lambda batch: sum(sample.natoms.sum() for sample in batch).item()
275289
)

0 commit comments

Comments (0)